diff --git a/.venv/lib/python3.12/site-packages/torch/_C/_dynamo/__init__.pyi b/.venv/lib/python3.12/site-packages/torch/_C/_dynamo/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..67d515697cbe4b43edb18dbdc4cf0270ebf13fb2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_C/_dynamo/__init__.pyi @@ -0,0 +1,4 @@ +from . import compiled_autograd, eval_frame, guards # noqa: F401 + +def strip_function_call(name: str) -> str: ... +def is_valid_var_name(name: str) -> bool | int: ... diff --git a/.venv/lib/python3.12/site-packages/torch/_C/_dynamo/compiled_autograd.pyi b/.venv/lib/python3.12/site-packages/torch/_C/_dynamo/compiled_autograd.pyi new file mode 100644 index 0000000000000000000000000000000000000000..321a99fc709b4758beb212ce599e5dd700e7cdaa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_C/_dynamo/compiled_autograd.pyi @@ -0,0 +1,13 @@ +from typing import Callable + +from torch import Tensor +from torch._dynamo.compiled_autograd import AutogradCompilerInstance + +def set_autograd_compiler( + autograd_compiler: Callable[[], AutogradCompilerInstance] | None, + dynamic: bool, +) -> tuple[Callable[[], AutogradCompilerInstance] | None, bool]: ... +def clear_cache() -> None: ... +def is_cache_empty() -> bool: ... +def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ... +def call_cpp_tensor_pre_hooks(idx: int, grad: Tensor) -> Tensor: ... diff --git a/.venv/lib/python3.12/site-packages/torch/_C/_dynamo/eval_frame.pyi b/.venv/lib/python3.12/site-packages/torch/_C/_dynamo/eval_frame.pyi new file mode 100644 index 0000000000000000000000000000000000000000..05dde69b0470cfc2ad1f6931f7f5fe0f8f973536 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_C/_dynamo/eval_frame.pyi @@ -0,0 +1,71 @@ +import enum +import types +from typing import Optional, overload + +from torch._dynamo.types import ( + DynamoCallback, + DynamoGuardCompleteHook, + DynamoGuardHook, + GuardFn, +) + +def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... +def set_skip_guard_eval_unsafe(value: bool) -> bool: ... +def get_eval_frame_callback() -> DynamoCallback: ... +def reset_code(code: types.CodeType) -> None: ... +def unsupported(obj1: object, obj2: object) -> object: ... +def set_code_exec_strategy( + code: types.CodeType, strategy: _FrameExecStrategy +) -> None: ... +def set_guard_error_hook(hook: DynamoGuardHook) -> None: ... +def set_guard_complete_hook( + hook: Optional[DynamoGuardCompleteHook], +) -> Optional[DynamoGuardCompleteHook]: ... +def raise_sigtrap() -> None: ... + +class _CacheEntry: + def check_fn(self, *args: object, **kwargs: object) -> bool: ... + code: types.CodeType + next: _CacheEntry | None + +class _ExtraState: + def invalidate(self, cache_entry: _CacheEntry, guard_manager: object) -> None: ... + +class _FrameAction(enum.IntEnum): + DEFAULT = 0 + SKIP = 1 + RUN_ONLY = 2 + +class _FrameExecStrategy: + cur_action: _FrameAction + recursive_action: _FrameAction + + @overload + def __init__(self) -> None: ... + @overload + def __init__( + self, cur_action: _FrameAction, recursive_action: _FrameAction + ) -> None: ... + +# This is an object that encapsulates the Python FrameType, and exposes +# properties Dynamo cares about for a frame. +class _PyInterpreterFrame: + f_code: types.CodeType + f_locals: dict[str, object] + f_globals: dict[str, object] + f_builtins: dict[str, object] + f_lasti: int + f_lineo: int + f_back: types.FrameType + # A tuple containing cell objects captured by this frame. + closure: tuple[types.CellType] + +def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ... + +py_opcode_caches: list[int] + +def code_framelocals_names(code: types.CodeType) -> tuple[str]: ... +def _load_precompile_entry( + code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType +) -> None: ... +def _reset_precompile_entries(code: types.CodeType) -> None: ... diff --git a/.venv/lib/python3.12/site-packages/torch/_C/_dynamo/guards.pyi b/.venv/lib/python3.12/site-packages/torch/_C/_dynamo/guards.pyi new file mode 100644 index 0000000000000000000000000000000000000000..67549d36fd4a0868c0f026366e3f4ebff9a73635 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_C/_dynamo/guards.pyi @@ -0,0 +1,191 @@ +# mypy: allow-untyped-defs +from typing import Any, Callable + +import torch + +class GlobalStateGuard: + def check(self) -> bool: ... + def reason(self) -> str: ... + +class LeafGuard: ... +class GuardDebugInfo: ... + +class GuardManager: + def check(self, value) -> bool: ... + def check_verbose(self, value) -> GuardDebugInfo: ... + + # Accessors + def globals_dict_manager( + self, + f_globals: dict[str, Any], + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def framelocals_manager( + self, + key: tuple[str, int], + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def dict_getitem_manager( + self, + key, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def global_weakref_manager( + self, + global_name: str, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def type_manager( + self, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def getattr_manager( + self, + attr: str, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def tensor_property_size_manager( + self, + idx: int, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def tensor_property_shape_manager( + self, + idx: int, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def tensor_property_storage_offset_manager( + self, + idx: None, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def indexed_manager( + self, + idx: int, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def lambda_manager( + self, + python_lambda, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + + # Leaf guards + def add_lambda_guard(self, user_lambda, verbose_code_parts: list[str]) -> None: ... + def add_id_match_guard(self, id_val, verbose_code_parts: list[str]) -> None: ... + def add_equals_match_guard( + self, + equals_val, + verbose_code_parts: list[str], + ) -> None: ... + def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ... + def add_torch_function_mode_stack_guard( + self, initial_stack, verbose_code_parts: list[str] + ) -> None: ... + def add_mapping_keys_guard(sef, value, verbose_code_parts: list[str]) -> None: ... + +class RootGuardManager(GuardManager): + def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ... + def add_epilogue_lambda_guard( + self, + guard: LeafGuard, + verbose_code_parts: list[str], + ) -> None: ... + def clone_manager( + self, clone_filter_fn: Callable[[GuardManager], bool] + ) -> RootGuardManager: ... + +class DictGuardManager(GuardManager): + def get_key_manager( + self, + index, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def get_value_manager( + self, + index, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + +def install_object_aliasing_guard( + guard_managers: list[GuardManager], + tensor_names: list[str], + verbose_code_parts: list[str], +): ... +def install_no_tensor_aliasing_guard( + guard_managers: list[GuardManager], + tensor_names: list[str], + verbose_code_parts: list[str], +): ... +def install_storage_overlapping_guard( + overlapping_guard_managers: list[GuardManager], + non_overlapping_guard_managers: list[GuardManager], + verbose_code_parts: list[str], +): ... +def install_symbolic_shape_guard( + guard_managers: list[GuardManager], + nargs_int: int, + nargs_float: int, + py_addr: int, + py_addr_keep_alive: Any, + verbose_code_parts: list[str], +): ... +def profile_guard_manager( + guard_manager: GuardManager, + f_locals: dict[str, Any], + n_iters: int, +) -> float: ... + +class TensorGuards: + def __init__( + self, + *, + dynamic_dims_sizes: list[torch.SymInt | None] | None = None, + dynamic_dims_strides: list[torch.SymInt | None] | None = None, + ) -> None: ... + def check(self, *args) -> bool: ... + def check_verbose(self, *args, tensor_check_names=None) -> bool | str: ... + +def assert_size_stride( + item: torch.Tensor, + size: torch.types._size, + stride: torch.types._size, + op_name: str | None = None, +): ... +def assert_alignment( + item: torch.Tensor, + alignment: int, + op_name: str | None = None, +): ... +def check_obj_id(obj: object, expected: int) -> bool: ... +def check_type_id(obj: object, expected: int) -> bool: ... +def dict_version(d: dict[Any, Any]) -> int: ... +def compute_overlapping_tensors( + tensors: list[torch.Tensor], symbolic: bool = True +) -> set[int]: ... diff --git a/.venv/lib/python3.12/site-packages/torch/_C/_export/__init__.pyi b/.venv/lib/python3.12/site-packages/torch/_C/_export/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..039f9c22eea620bc9675d233684df72c7ac4471c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_C/_export/__init__.pyi @@ -0,0 +1,9 @@ +# Defined in torch/csrc/export/pybind.cpp +class CppExportedProgram: ... + +def deserialize_exported_program( + serialized_program: str, +) -> CppExportedProgram: ... +def serialize_exported_program( + cpp_exported_program: CppExportedProgram, +) -> str: ... diff --git a/.venv/lib/python3.12/site-packages/torch/_C/_export/pt2_archive_constants.pyi b/.venv/lib/python3.12/site-packages/torch/_C/_export/pt2_archive_constants.pyi new file mode 100644 index 0000000000000000000000000000000000000000..87e356453bcf0a6ac41e31a998b4d33b711dedb0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/_C/_export/pt2_archive_constants.pyi @@ -0,0 +1,22 @@ +# Defined in torch/csrc/export/pt2_archive_constants.h + +ARCHIVE_ROOT_NAME: str = ... +ARCHIVE_FORMAT_PATH: str = ... +ARCHIVE_FORMAT_VALUE: str = ... +ARCHIVE_VERSION_PATH: str = ... +ARCHIVE_VERSION_VALUE: str = ... +MODELS_DIR: str = ... +MODELS_FILENAME_FORMAT: str = ... +AOTINDUCTOR_DIR: str = ... +MTIA_DIR: str = ... +WEIGHTS_DIR: str = ... +WEIGHT_FILENAME_PREFIX: str = ... +CONSTANTS_DIR: str = ... +TENSOR_CONSTANT_FILENAME_PREFIX: str = ... +CUSTOM_OBJ_FILENAME_PREFIX: str = ... +SAMPLE_INPUTS_DIR: str = ... +SAMPLE_INPUTS_FILENAME_FORMAT: str = ... +EXTRA_DIR: str = ... +MODULE_INFO_PATH: str = ... +XL_MODEL_WEIGHTS_DIR: str = ... +XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH: str = ... diff --git a/.venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f2fb15ebfbe938f9f3df782098cf62344ebded9 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/autograd.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/autograd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b936913894af5be6ebb2777d7e08b26edaebf4c6 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/autograd.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/impl.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/impl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b4ecd3729ccf01c13b35b39cb90c4bfaf101176 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/impl.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a53e1fe7077b9ec78a155bb157845c80fae83470 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed58cc398dac4638ec9bed65e9f7489170cdc6ff Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3cd40f9148aa8c01e5d3348f1ddead66c3f2271 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d85924f68591b42f2ba91cacba435142d1145c7 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7439c22d66882d058e617edb85bc4407cfd742a9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/__init__.py @@ -0,0 +1,35 @@ +# We are exposing all subpackages to the end-user. +# Because of possible inter-dependency, we want to avoid +# the cyclic imports, thus implementing lazy version +# as per https://peps.python.org/pep-0562/ + +from typing import TYPE_CHECKING as _TYPE_CHECKING + + +if _TYPE_CHECKING: + from types import ModuleType + + from torch.ao.nn import ( # noqa: TC004 + intrinsic as intrinsic, + qat as qat, + quantizable as quantizable, + quantized as quantized, + sparse as sparse, + ) + + +__all__ = [ + "intrinsic", + "qat", + "quantizable", + "quantized", + "sparse", +] + + +def __getattr__(name: str) -> "ModuleType": + if name in __all__: + import importlib + + return importlib.import_module("." + name, __name__) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08078ef9cde821c941e02ae55dddd2f9840f4bcd Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..80ba84a84251db6229c38b5f2c48b233fe594fbb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__init__.py @@ -0,0 +1,41 @@ +import types + +from .modules import * # noqa: F403 +from .modules.fused import _FusedModule # noqa: F403 + + +# # Subpackages +# from . import qat # noqa: F403 +# from . import quantized # noqa: F403 + +__all__ = [ + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "BNReLU2d", + "BNReLU3d", + "LinearBn1d", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] + + +# We are exposing all subpackages to the end-user. +# Because of possible inter-dependency, we want to avoid +# the cyclic imports, thus implementing lazy version +# as per https://peps.python.org/pep-0562/ +def __getattr__(name: str) -> types.ModuleType: + if name in __all__: + import importlib + + return importlib.import_module("." + name, __name__) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8a44d2ad9f7169bc64bdab18a98a00ef1caff91 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..132137b7357378fe29ef9a63310a554725aea86a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__init__.py @@ -0,0 +1,41 @@ +from .fused import ( # noqa: F401 + _FusedModule, + BNReLU2d, + BNReLU3d, + ConvAdd2d, + ConvAddReLU2d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + LinearBn1d, + LinearLeakyReLU, + LinearReLU, + LinearTanh, +) + + +__all__ = [ + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "BNReLU2d", + "BNReLU3d", + "LinearBn1d", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..998897aa54583056033568a42587aefa479c3c66 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1627101409f5cb46c1a3a45779a45fb600d8222 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/fused.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/fused.py new file mode 100644 index 0000000000000000000000000000000000000000..ec5b9c26fdd0045a17d6d435ddf7e932a558988d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/fused.py @@ -0,0 +1,287 @@ +# mypy: allow-untyped-defs +import torch +from torch.nn import ( + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, + Conv1d, + Conv2d, + Conv3d, + Linear, + ReLU, +) +from torch.nn.utils.parametrize import type_before_parametrizations + + +__all__ = [ + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "ConvBn1d", + "ConvBn2d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBn3d", + "ConvBnReLU3d", + "BNReLU2d", + "BNReLU3d", + "LinearBn1d", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] + + +# Used for identifying intrinsic modules used in quantization +class _FusedModule(torch.nn.Sequential): + pass + + +class ConvReLU1d(_FusedModule): + r"""This is a sequential container which calls the Conv1d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, relu): + assert ( + type_before_parametrizations(conv) == Conv1d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, relu) + + +class ConvReLU2d(_FusedModule): + r"""This is a sequential container which calls the Conv2d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, relu): + assert ( + type_before_parametrizations(conv) == Conv2d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, relu) + + +class ConvReLU3d(_FusedModule): + r"""This is a sequential container which calls the Conv3d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, relu): + assert ( + type_before_parametrizations(conv) == Conv3d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, relu) + + +class LinearReLU(_FusedModule): + r"""This is a sequential container which calls the Linear and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, linear, relu): + assert ( + type_before_parametrizations(linear) == Linear + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(linear)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(linear, relu) + + +class ConvBn1d(_FusedModule): + r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn): + assert ( + type_before_parametrizations(conv) == Conv1d + and type_before_parametrizations(bn) == BatchNorm1d + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + ) + super().__init__(conv, bn) + + +class ConvBn2d(_FusedModule): + r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn): + assert ( + type_before_parametrizations(conv) == Conv2d + and type_before_parametrizations(bn) == BatchNorm2d + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + ) + super().__init__(conv, bn) + + +class ConvBnReLU1d(_FusedModule): + r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn, relu): + assert ( + type_before_parametrizations(conv) == Conv1d + and type_before_parametrizations(bn) == BatchNorm1d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, bn, relu) + + +class ConvBnReLU2d(_FusedModule): + r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn, relu): + assert ( + type_before_parametrizations(conv) == Conv2d + and type_before_parametrizations(bn) == BatchNorm2d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, bn, relu) + + +class ConvBn3d(_FusedModule): + r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn): + assert ( + type_before_parametrizations(conv) == Conv3d + and type_before_parametrizations(bn) == BatchNorm3d + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + ) + super().__init__(conv, bn) + + +class ConvBnReLU3d(_FusedModule): + r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, bn, relu): + assert ( + type_before_parametrizations(conv) == Conv3d + and type_before_parametrizations(bn) == BatchNorm3d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(conv)}" + f"{type_before_parametrizations(bn)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(conv, bn, relu) + + +class BNReLU2d(_FusedModule): + r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, batch_norm, relu): + assert ( + type_before_parametrizations(batch_norm) == BatchNorm2d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(batch_norm, relu) + + +class BNReLU3d(_FusedModule): + r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, batch_norm, relu): + assert ( + type_before_parametrizations(batch_norm) == BatchNorm3d + and type_before_parametrizations(relu) == ReLU + ), ( + f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}" + f"{type_before_parametrizations(relu)}" + ) + super().__init__(batch_norm, relu) + + +class LinearBn1d(_FusedModule): + r"""This is a sequential container which calls the Linear and BatchNorm1d modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, linear, bn): + assert ( + type_before_parametrizations(linear) == Linear + and type_before_parametrizations(bn) == BatchNorm1d + ), ( + f"Incorrect types for input modules{type_before_parametrizations(linear)}" + f"{type_before_parametrizations(bn)}" + ) + super().__init__(linear, bn) + + +class LinearLeakyReLU(_FusedModule): + r"""This is a sequential container which calls the Linear and LeakyReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, linear, leaky_relu): + assert type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU, ( + f"Incorrect types for input modules{type(linear)}{type(leaky_relu)}" + ) + super().__init__(linear, leaky_relu) + + +class LinearTanh(_FusedModule): + r"""This is a sequential container which calls the Linear and Tanh modules. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, linear, tanh): + assert type(linear) == Linear and type(tanh) == torch.nn.Tanh, ( + f"Incorrect types for input modules{type(linear)}{type(tanh)}" + ) + super().__init__(linear, tanh) + + +class ConvAdd2d(_FusedModule): + r"""This is a sequential container which calls the Conv2d modules with extra Add. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, add): + super().__init__(conv) + self.add = add + + def forward(self, x1, x2): # type: ignore[override] + return self.add(self[0](x1), x2) + + +class ConvAddReLU2d(_FusedModule): + r"""This is a sequential container which calls the Conv2d, add, Relu. + During quantization this will be replaced with the corresponding fused module.""" + + def __init__(self, conv, add, relu): + super().__init__(conv) + self.add = add + self.relu = relu + + def forward(self, x1, x2): # type: ignore[override] + return self.relu(self.add(self[0](x1), x2)) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b793ef0f9954c0f1f040ee57871e0e27bbfc4c2 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18534bbc588e7480ac6529c6648c5976eadaea3a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py @@ -0,0 +1,32 @@ +from .conv_fused import ( + ConvBn1d, + ConvBn2d, + ConvBn3d, + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvReLU1d, + ConvReLU2d, + ConvReLU3d, + freeze_bn_stats, + update_bn_stats, +) +from .linear_fused import LinearBn1d +from .linear_relu import LinearReLU + + +__all__ = [ + "LinearReLU", + "LinearBn1d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + "update_bn_stats", + "freeze_bn_stats", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fdeacce4be42dbc7af0e44d46368295bb114bc7 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..945304d67374eb29ffb4c4384461b6bb256f2a65 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2d08b46caac6a38433abeac973f13b21cd30441 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b67e5e1ceb90ec717fa5cf18bbbcc5422d502016 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..6671e317b6b02ecaefaa2c78fbef39b77faad912 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -0,0 +1,1064 @@ +# mypy: allow-untyped-defs +import math +from typing import ClassVar, Optional + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.qat as nnqat +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +from torch.nn.modules.utils import _pair, _single, _triple +from torch.nn.parameter import Parameter +from torch.nn.utils import fuse_conv_bn_weights + + +__all__ = [ + "ConvBn1d", + "ConvBnReLU1d", + "ConvReLU1d", + "ConvBn2d", + "ConvBnReLU2d", + "ConvReLU2d", + "ConvBn3d", + "ConvBnReLU3d", + "ConvReLU3d", + "update_bn_stats", + "freeze_bn_stats", +] +_BN_CLASS_MAP = { + 1: nn.BatchNorm1d, + 2: nn.BatchNorm2d, + 3: nn.BatchNorm3d, +} + + +class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule): + _version = 2 + _FLOAT_MODULE: ClassVar[type[nn.modules.conv._ConvNd]] + + def __init__( + self, + # ConvNd args + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + # BatchNormNd args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + dim=2, + ): + nn.modules.conv._ConvNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + False, + padding_mode, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.freeze_bn = freeze_bn if self.training else True + self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True) + self.weight_fake_quant = self.qconfig.weight() + if bias: + self.bias = Parameter(torch.empty(out_channels)) + else: + self.register_parameter("bias", None) + self.reset_bn_parameters() + + # this needs to be called after reset_bn_parameters, + # as they modify the same state + if self.training: + if freeze_bn: + self.freeze_bn_stats() + else: + self.update_bn_stats() + else: + self.freeze_bn_stats() + + self._enable_slow_path_for_better_numerical_stability = False + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def reset_bn_parameters(self): + self.bn.reset_running_stats() + init.uniform_(self.bn.weight) + init.zeros_(self.bn.bias) + # note: below is actually for conv, not BN + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def reset_parameters(self): + super().reset_parameters() + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def _forward(self, input): + if self._enable_slow_path_for_better_numerical_stability: + return self._forward_slow(input) + return self._forward_approximate(input) + + def _forward_approximate(self, input): + """Approximated method to fuse conv and bn. It requires only one forward pass. + conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std + """ + assert self.bn.running_var is not None + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + weight_shape = [1] * len(self.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.weight.shape) + bias_shape[1] = -1 + scaled_weight = self.weight_fake_quant( + self.weight * scale_factor.reshape(weight_shape) + ) + # using zero bias here since the bias for original conv + # will be added later + if self.bias is not None: + zero_bias = torch.zeros_like(self.bias, dtype=input.dtype) + else: + zero_bias = torch.zeros( + self.out_channels, device=scaled_weight.device, dtype=input.dtype + ) + conv = self._conv_forward(input, scaled_weight, zero_bias) + conv_orig = conv / scale_factor.reshape(bias_shape) + if self.bias is not None: + conv_orig = conv_orig + self.bias.reshape(bias_shape) + conv = self.bn(conv_orig) + return conv + + def _forward_slow(self, input): + """ + A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf + It requires two forward passes but handles the case bn.weight == 0 + + Conv: Y = WX + B_c + Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c + + Batch statistics: + mean_Y = Y.mean() + = Y0.mean() + B_c + var_Y = (Y - mean_Y)^2.mean() + = (Y0 - Y0.mean())^2.mean() + BN (r: bn.weight, beta: bn.bias): + Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta + = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta + + Fused Conv BN training (std_Y = sqrt(var_Y + eps)): + Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta + = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta + + Fused Conv BN inference (running_std = sqrt(running_var + eps)): + Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta + + QAT with fused conv bn: + Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta + = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta + Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta + """ + + assert self.bn.running_var is not None + assert self.bn.running_mean is not None + + # using zero bias here since the bias for original conv + # will be added later + zero_bias = torch.zeros( + self.out_channels, device=self.weight.device, dtype=input.dtype + ) + + weight_shape = [1] * len(self.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.weight.shape) + bias_shape[1] = -1 + + if self.bn.training: + # needed to compute batch mean/std + conv_out = self._conv_forward(input, self.weight, zero_bias) + # update bn statistics + with torch.no_grad(): + conv_out_bias = ( + conv_out + if self.bias is None + else conv_out + self.bias.reshape(bias_shape) + ) + self.bn(conv_out_bias) + + # fused conv + bn without bias using bn running statistics + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + scaled_weight = self.weight_fake_quant( + self.weight * scale_factor.reshape(weight_shape) + ) + # fused conv without bias for inference: (r * W / running_std) * X + conv_bn = self._conv_forward(input, scaled_weight, zero_bias) + + avg_dims = [0] + list(range(2, len(self.weight.shape))) + batch_mean = conv_out.mean(avg_dims) + batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean( + avg_dims + ) + batch_std = torch.sqrt(batch_var + self.bn.eps) + + # scale to use batch std in training mode + # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y) + unscale_factor = running_std / batch_std + conv_bn *= unscale_factor.reshape(bias_shape) + + fused_mean = batch_mean + fused_std = batch_std + else: + # fused conv + bn without bias using bn running statistics + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + scaled_weight = self.weight_fake_quant( + self.weight * scale_factor.reshape(weight_shape) + ) + # fused conv without bias for inference: (r * W / running_std) * X + conv_bn = self._conv_forward(input, scaled_weight, zero_bias) + + fused_mean = self.bn.running_mean - ( + self.bias if self.bias is not None else 0 + ) + fused_std = running_std + + # fused bias = beta - r * mean / std + fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std + conv_bn += fused_bias.reshape(bias_shape) + + # HACK to let conv bias participate in loss to avoid DDP error (parameters + # were not used in producing loss) + if self.bias is not None: + conv_bn += (self.bias - self.bias).reshape(bias_shape) + + return conv_bn + + def extra_repr(self): + # TODO(jerryzh): extend + return super().extra_repr() + + def forward(self, input): + return self._forward(input) + + def train(self, mode=True): + """ + Batchnorm's training behavior is using the self.training flag. Prevent + changing it if BN is frozen. This makes sure that calling `model.train()` + on a model with a frozen BN will behave properly. + """ + self.training = mode + if not self.freeze_bn: + for module in self.children(): + module.train(mode) + return self + + # ===== Serialization version history ===== + # + # Version 1/None + # self + # |--- weight : Tensor + # |--- bias : Tensor + # |--- gamma : Tensor + # |--- beta : Tensor + # |--- running_mean : Tensor + # |--- running_var : Tensor + # |--- num_batches_tracked : Tensor + # + # Version 2 + # self + # |--- weight : Tensor + # |--- bias : Tensor + # |--- bn : Module + # |--- weight : Tensor (moved from v1.self.gamma) + # |--- bias : Tensor (moved from v1.self.beta) + # |--- running_mean : Tensor (moved from v1.self.running_mean) + # |--- running_var : Tensor (moved from v1.self.running_var) + # |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked) + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + if version is None or version == 1: + # BN related parameters and buffers were moved into the BN module for v2 + v2_to_v1_names = { + "bn.weight": "gamma", + "bn.bias": "beta", + "bn.running_mean": "running_mean", + "bn.running_var": "running_var", + "bn.num_batches_tracked": "num_batches_tracked", + } + for v2_name, v1_name in v2_to_v1_names.items(): + if prefix + v1_name in state_dict: + state_dict[prefix + v2_name] = state_dict[prefix + v1_name] + state_dict.pop(prefix + v1_name) + elif prefix + v2_name in state_dict: + # there was a brief period where forward compatibility + # for this module was broken (between + # https://github.com/pytorch/pytorch/pull/38478 + # and https://github.com/pytorch/pytorch/pull/38820) + # and modules emitted the v2 state_dict format while + # specifying that version == 1. This patches the forward + # compatibility issue by allowing the v2 style entries to + # be used. + pass + elif strict: + missing_keys.append(prefix + v2_name) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module or qparams_dict + + Args: `mod` a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound + # has no __name__ (code is fine though) + assert type(mod) == cls._FLOAT_MODULE, ( + "qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + qconfig = mod.qconfig + conv, bn = mod[0], mod[1] # type: ignore[index] + qat_convbn = cls( + conv.in_channels, + conv.out_channels, + conv.kernel_size, + conv.stride, + conv.padding, + conv.dilation, + conv.groups, + conv.bias is not None, + conv.padding_mode, + bn.eps, + bn.momentum, + False, + qconfig, + ) + qat_convbn.weight = conv.weight + qat_convbn.bias = conv.bias + qat_convbn.bn.weight = bn.weight + qat_convbn.bn.bias = bn.bias + qat_convbn.bn.running_mean = bn.running_mean + qat_convbn.bn.running_var = bn.running_var + # mypy error: Cannot determine type of 'num_batches_tracked' + qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked + return qat_convbn + + def to_float(self): + cls = type(self) + conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined] + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.groups, + self.bias is not None, + self.padding_mode, + ) + conv.weight = torch.nn.Parameter(self.weight.detach()) + if self.bias is not None: + conv.bias = torch.nn.Parameter(self.bias.detach()) + + if cls._FLOAT_BN_MODULE: # type: ignore[attr-defined] + # fuse bn into conv + assert self.bn.running_var is not None and self.bn.running_mean is not None + conv.weight, conv.bias = fuse_conv_bn_weights( + conv.weight, + conv.bias, + self.bn.running_mean, + self.bn.running_var, + self.bn.eps, + self.bn.weight, + self.bn.bias, + ) + + if cls._FLOAT_RELU_MODULE: # type: ignore[attr-defined] + modules = [] + modules.append(conv) + relu = cls._FLOAT_RELU_MODULE() # type: ignore[attr-defined] + modules.append(relu) + conv_relu = cls._FUSED_FLOAT_MODULE(*modules) # type: ignore[attr-defined] + conv_relu.train(self.training) + return conv_relu + else: + conv.train(self.training) + return conv + + +class ConvBn1d(_ConvBnNd, nn.Conv1d): + r""" + A ConvBn1d module is a module fused from Conv1d and BatchNorm1d, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv1d` and + :class:`torch.nn.BatchNorm1d`. + + Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized + to default. + + Attributes: + freeze_bn: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm1d]] = nn.BatchNorm1d + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvBn1d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + + def __init__( + self, + # Conv1d args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm1d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + _ConvBnNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _single(0), + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + dim=1, + ) + + +class ConvBnReLU1d(ConvBn1d): + r""" + A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv1d` and + :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`. + + Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + # base class defines _FLOAT_MODULE as "ConvBn1d" + _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvBnReLU1d + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm1d]] = nn.BatchNorm1d + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU + # module class after fusing bn into conv + _FUSED_FLOAT_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU1d + + def __init__( + self, + # Conv1d args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm1d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + ) + + def forward(self, input): + return F.relu(self._forward(input)) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) + + +class ConvReLU1d(nnqat.Conv1d, nni._FusedModule): + r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with + FakeQuantize modules for weight for + quantization aware training. + + We combined the interface of :class:`~torch.nn.Conv1d` and + :class:`~torch.nn.BatchNorm1d`. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvReLU1d]] = nni.ConvReLU1d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = self.qconfig.weight() + + def forward(self, input): + return F.relu( + self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class ConvBn2d(_ConvBnNd, nn.Conv2d): + r""" + A ConvBn2d module is a module fused from Conv2d and BatchNorm2d, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv2d` and + :class:`torch.nn.BatchNorm2d`. + + Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized + to default. + + Attributes: + freeze_bn: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvBn2d]] = nni.ConvBn2d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.BatchNorm2d + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + # ConvNd args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm2d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + _ConvBnNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _pair(0), + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + dim=2, + ) + + +class ConvBnReLU2d(ConvBn2d): + r""" + A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv2d` and + :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`. + + Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + # base class defines _FLOAT_MODULE as "ConvBn2d" + _FLOAT_MODULE: ClassVar[type[nni.ConvBnReLU2d]] = nni.ConvBnReLU2d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm2d]] = nn.BatchNorm2d + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU + # module class after fusing bn into conv + _FUSED_FLOAT_MODULE: ClassVar[Optional[type[nni.ConvReLU2d]]] = nni.ConvReLU2d + + def __init__( + self, + # Conv2d args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm2d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + ) + + def forward(self, input): + return F.relu(self._forward(input)) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) + + +class ConvReLU2d(nnqat.Conv2d, nni._FusedModule): + r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with + FakeQuantize modules for weight for + quantization aware training. + + We combined the interface of :class:`~torch.nn.Conv2d` and + :class:`~torch.nn.BatchNorm2d`. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvReLU2d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = self.qconfig.weight() + + def forward(self, input): + return F.relu( + self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class ConvBn3d(_ConvBnNd, nn.Conv3d): + r""" + A ConvBn3d module is a module fused from Conv3d and BatchNorm3d, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv3d` and + :class:`torch.nn.BatchNorm3d`. + + Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized + to default. + + Attributes: + freeze_bn: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvBn3d]] = nni.ConvBn3d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.BatchNorm3d + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + # ConvNd args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm3d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + _ConvBnNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _triple(0), + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + dim=3, + ) + + +class ConvBnReLU3d(ConvBn3d): + r""" + A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU, + attached with FakeQuantize modules for weight, + used in quantization aware training. + + We combined the interface of :class:`torch.nn.Conv3d` and + :class:`torch.nn.BatchNorm3d` and :class:`torch.nn.ReLU`. + + Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvBnReLU3d]] = nni.ConvBnReLU3d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm3d]] = nn.BatchNorm3d + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.ReLU]]] = nn.ReLU + # module class after fusing bn into conv + _FUSED_FLOAT_MODULE: ClassVar[Optional[type[nni.ConvReLU3d]]] = nni.ConvReLU3d + + def __init__( + self, + # Conv3d args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm3d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + ) + + def forward(self, input): + return F.relu(ConvBn3d._forward(self, input)) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class ConvReLU3d(nnqat.Conv3d, nni._FusedModule): + r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with + FakeQuantize modules for weight for + quantization aware training. + + We combined the interface of :class:`~torch.nn.Conv3d` and + :class:`~torch.nn.BatchNorm3d`. + + Attributes: + weight_fake_quant: fake quant module for weight + + """ + + _FLOAT_MODULE: ClassVar[type[nni.ConvReLU3d]] = nni.ConvReLU3d # type: ignore[assignment] + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = self.qconfig.weight() + + def forward(self, input): + return F.relu( + self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +def update_bn_stats(mod): + if type(mod) in { + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + }: + mod.update_bn_stats() + + +def freeze_bn_stats(mod): + if type(mod) in { + ConvBnReLU1d, + ConvBnReLU2d, + ConvBnReLU3d, + ConvBn1d, + ConvBn2d, + ConvBn3d, + }: + mod.freeze_bn_stats() diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..aada0ab2ab7144c38ece21c78fd3050c75e28062 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py @@ -0,0 +1,193 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +from torch.nn.parameter import Parameter +from torch.nn.utils.fusion import fuse_linear_bn_weights + + +__all__ = [ + "LinearBn1d", +] + + +class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule): + r""" + A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached + with FakeQuantize modules for weight, used in quantization aware training. + + We combined the interface of :class:`torch.nn.Linear` and + :class:torch.nn.BatchNorm1d`. + + Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized + to default. + + Attributes: + freeze_bn: + weight_fake_quant: fake quant module for weight + + """ + + def __init__( + self, + # Linear args + in_features, + out_features, + bias=True, + # BatchNorm1d args + # num_features: out_features + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + nn.modules.linear.Linear.__init__(self, in_features, out_features, bias) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.freeze_bn = freeze_bn if self.training else True + self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True) + self.weight_fake_quant = self.qconfig.weight() + if bias: + self.bias = Parameter(torch.empty(out_features)) + else: + self.register_parameter("bias", None) + self.reset_bn_parameters() + + # this needs to be called after reset_bn_parameters, + # as they modify the same state + if self.training: + if freeze_bn: + self.freeze_bn_stats() + else: + self.update_bn_stats() + else: + self.freeze_bn_stats() + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def reset_bn_parameters(self): + self.bn.reset_running_stats() + init.uniform_(self.bn.weight) + init.zeros_(self.bn.bias) + + def reset_parameters(self): + super().reset_parameters() + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def forward(self, input): + assert self.bn.running_var is not None + + # Scale the linear weights by BN's running statistics to reduce + # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18 + # for motivation. + # + # Instead of + # + # x1 = F.linear(x0, fq(w), b) + # x2 = self.bn(x1) + # + # We have + # + # # scale the weight by previous batch's running statistics + # scale_factor = bn.w / bn.running_std_from_prev_batch + # # do the linear transformation without bias + # x1_scaled = F.linear(x0, fq(w * scale_factor), 0) + # # reverse the scaling and add original bias + # x1_orig = x1_scaled / scale_factor + b + # x2 = self.bn(x1_orig) + + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + weight_shape = [1] * len(self.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.weight.shape) + bias_shape[1] = -1 + scaled_weight = self.weight_fake_quant( + self.weight * scale_factor.reshape(weight_shape) + ) + if self.bias is not None: + zero_bias = torch.zeros_like(self.bias) + else: + zero_bias = torch.zeros(self.out_features, device=scaled_weight.device) + linear_out = F.linear(input, scaled_weight, zero_bias) + linear_out_orig = linear_out / scale_factor.reshape(bias_shape) + if self.bias is not None: + linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape) + bn_out = self.bn(linear_out_orig) + return bn_out + + def train(self, mode=True): + """ + Batchnorm's training behavior is using the self.training flag. Prevent + changing it if BN is frozen. This makes sure that calling `model.train()` + on a model with a frozen BN will behave properly. + """ + self.training = mode + if not self.freeze_bn: + for module in self.children(): + module.train(mode) + return self + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module or qparams_dict + + Args: `mod' a float module, either produced by torch.ao.quantization + utilities or directly from user + """ + assert type(mod) == nni.LinearBn1d, ( + "qat." + + cls.__name__ + + ".from_float only works for " + + nni.LinearBn1d.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid config" + qconfig = mod.qconfig + linear, bn = mod[0], mod[1] + qat_linearbn = cls( + linear.in_features, + linear.out_features, + linear.bias is not None, + bn.eps, + bn.momentum, + False, + qconfig, + ) + qat_linearbn.weight = linear.weight # type: ignore[assignment] + qat_linearbn.bias = linear.bias # type: ignore[assignment] + qat_linearbn.bn.weight = bn.weight # type: ignore[assignment] + qat_linearbn.bn.bias = bn.bias # type: ignore[assignment] + qat_linearbn.bn.running_mean = bn.running_mean # type: ignore[assignment] + qat_linearbn.bn.running_var = bn.running_var # type: ignore[assignment] + qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked # type: ignore[assignment] + return qat_linearbn + + def to_float(self): + linear = torch.nn.Linear(self.in_features, self.out_features) + assert self.bn.running_var is not None and self.bn.running_mean is not None + linear.weight, linear.bias = fuse_linear_bn_weights( + self.weight, + self.bias, + self.bn.running_mean, + self.bn.running_var, + self.bn.eps, + self.bn.weight, + self.bn.bias, + ) + return linear diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..06cd482bb61a05deb0555d144aa707c178f01408 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py @@ -0,0 +1,52 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.qat as nnqat +import torch.nn.functional as F + + +class LinearReLU(nnqat.Linear, nni._FusedModule): + r""" + A LinearReLU module fused from Linear and ReLU modules, attached with + FakeQuantize modules for weight, used in + quantization aware training. + + We adopt the same interface as :class:`torch.nn.Linear`. + + Similar to `torch.ao.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to + default. + + Attributes: + weight: fake quant module for weight + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.qat.LinearReLU(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, qconfig=None): + super().__init__(in_features, out_features, bias, qconfig) + + def forward(self, input): + return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias)) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) + + def to_float(self): + linear = torch.nn.Linear( + self.in_features, self.out_features, self.bias is not None + ) + linear.weight = torch.nn.Parameter(self.weight.detach()) + if self.bias is not None: + linear.bias = torch.nn.Parameter(self.bias.detach()) + relu = torch.nn.ReLU() + return torch.ao.nn.intrinsic.LinearReLU(linear, relu) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6af3b4aeee893966323cc4e73a27ff41814fc251 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py @@ -0,0 +1,15 @@ +from .modules import * # noqa: F403 + + +__all__ = [ + "BNReLU2d", + "BNReLU3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "LinearReLU", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..892897e4b4308af185673095833a07933a962db5 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a5e28d98fbe0cc7d0564f9069f88989de7725a8 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a6c3c57c7828861b574e76b134aee2c23f0aad --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py @@ -0,0 +1,6 @@ +from .linear_relu import LinearReLU + + +__all__ = [ + "LinearReLU", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9848ff1118bb13f502598a8b13846f7e2d4783c3 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9052a1caa4213bade11f313e7901eae49c274f21 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..f19c2c8e9d9db8ae0828d5d4601b3be8dba11231 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py @@ -0,0 +1,61 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.quantized.dynamic as nnqd + + +__all__ = ["LinearReLU"] + + +class LinearReLU(nnqd.Linear): + r""" + A LinearReLU module fused from Linear and ReLU modules that can be used + for dynamic quantization. + Supports both, FP16 and INT8 quantization. + + We adopt the same interface as :class:`torch.ao.nn.quantized.dynamic.Linear`. + + Attributes: + Same as torch.ao.nn.quantized.dynamic.Linear + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.quantized.dynamic.LinearReLU(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self._packed_params.dtype == torch.qint8: + # TODO check if we should set reduce_rage = True by default here + Y = torch.ops.quantized.linear_relu_dynamic( + x, self._packed_params._packed_params, reduce_range=True + ) + elif self._packed_params.dtype == torch.float16: + Y = torch.ops.quantized.linear_relu_dynamic_fp16( + x, self._packed_params._packed_params + ) + else: + raise RuntimeError("Unsupported dtype on dynamic quantized linear relu!") + return Y.to(x.dtype) + + def _get_name(self): + return "DynamicQuantizedLinearReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qlinear_relu): # type: ignore[override] + return super().from_reference(ref_qlinear_relu[0]) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fa4dcec2597e18c002489405894ea7251d5156 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py @@ -0,0 +1,18 @@ +from .bn_relu import BNReLU2d, BNReLU3d +from .conv_add import ConvAdd2d, ConvAddReLU2d +from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d +from .linear_relu import LinearLeakyReLU, LinearReLU, LinearTanh + + +__all__ = [ + "LinearReLU", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "BNReLU2d", + "BNReLU3d", + "LinearLeakyReLU", + "LinearTanh", + "ConvAdd2d", + "ConvAddReLU2d", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5fdc6076997776cfd59ffce103a80ed78e61f9c Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0563ac8086578f3eda195731f09227b148fc037 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af96dad88f0b398dd8ca2e624b0ce89e27c0f4f8 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d241cced3445cd46887734285896cd51d9f13df Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..266c94fc1f0dfffd9c181af8d07dfc690f813a26 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..99b535625cbc7e3beb888ed1c61fa1e1b114853e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py @@ -0,0 +1,107 @@ +# mypy: allow-untyped-defs + +import torch +import torch.ao.nn.intrinsic +import torch.ao.nn.intrinsic.qat +import torch.ao.nn.quantized as nnq + + +__all__ = ["BNReLU2d", "BNReLU3d"] + + +class BNReLU2d(nnq.BatchNorm2d): + r""" + A BNReLU2d module is a fused module of BatchNorm2d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm2d`. + + Attributes: + Same as torch.ao.nn.quantized.BatchNorm2d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU2d + + def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None): + super().__init__( + num_features, eps=eps, momentum=momentum, device=device, dtype=dtype + ) + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + return torch.ops.quantized.batch_norm2d_relu( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + self.eps, + self.scale, + self.zero_point, + ) + + def _get_name(self): + return "QuantizedBNReLU2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + # TODO: Add qat support for BNReLU2d + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, bn_relu, output_scale, output_zero_point): + return super().from_reference(bn_relu[0], output_scale, output_zero_point) + + +class BNReLU3d(nnq.BatchNorm3d): + r""" + A BNReLU3d module is a fused module of BatchNorm3d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm3d`. + + Attributes: + Same as torch.ao.nn.quantized.BatchNorm3d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU3d + + def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None): + super().__init__( + num_features, eps=eps, momentum=momentum, device=device, dtype=dtype + ) + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + return torch.ops.quantized.batch_norm3d_relu( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + self.eps, + self.scale, + self.zero_point, + ) + + def _get_name(self): + return "QuantizedBNReLU3d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + # TODO: Add qat support for BNReLU3d + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, bn_relu, output_scale, output_zero_point): + return super().from_reference(bn_relu[0], output_scale, output_zero_point) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py new file mode 100644 index 0000000000000000000000000000000000000000..71bfa845f150ae09745ce1c6941b16b2c6583fd8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py @@ -0,0 +1,147 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic +import torch.ao.nn.intrinsic.qat +import torch.ao.nn.quantized as nnq +import torch.nn.functional as F + + +_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding + + +class ConvAdd2d(nnq.Conv2d): + r""" + A ConvAdd2d module is a fused module of Conv2d and Add + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`. + + Attributes: + Same as torch.ao.nn.quantized.Conv2d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAdd2d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input, extra_input): # type: ignore[override] + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv2d_add( + input, extra_input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvAdd2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) + + +class ConvAddReLU2d(nnq.Conv2d): + r""" + A ConvAddReLU2d module is a fused module of Conv2d, Add and Relu + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`. + + Attributes: + Same as torch.ao.nn.quantized.Conv2d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input, extra_input): # type: ignore[override] + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv2d_add_relu( + input, extra_input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvAddReLU2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..8172004d95fc800dee989d4c47382a600eb01fd9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py @@ -0,0 +1,266 @@ +# mypy: allow-untyped-defs + +import torch +import torch.ao.nn.intrinsic +import torch.ao.nn.intrinsic.qat +import torch.ao.nn.quantized as nnq +import torch.nn.functional as F +from torch.nn.utils import fuse_conv_bn_weights + + +__all__ = [ + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", +] + +_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding + + +# TODO: factor out the common parts to ConvNd +class ConvReLU1d(nnq.Conv1d): + r""" + A ConvReLU1d module is a fused module of Conv1d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`. + + Attributes: + Same as torch.ao.nn.quantized.Conv1d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != "zeros": + # Padding in Conv1d is stored as (p, p), need to get (p,) + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1]) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv1d_relu( + input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvReLU1d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d: + assert mod.bn.running_var is not None and mod.bn.running_mean is not None + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + return super().from_float(mod, use_precomputed_fake_quant) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d, ( + "BatchNorm1d should be fused into Conv1d before converting to reference module" + ) + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) + + +class ConvReLU2d(nnq.Conv2d): + r""" + A ConvReLU2d module is a fused module of Conv2d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`. + + Attributes: + Same as torch.ao.nn.quantized.Conv2d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv2d_relu( + input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvReLU2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d: + assert mod.bn.running_var is not None and mod.bn.running_mean is not None + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d, ( + "BatchNorm2d should be fused into Conv2d before converting to reference module" + ) + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) + + +class ConvReLU3d(nnq.Conv3d): + r""" + A ConvReLU3d module is a fused module of Conv3d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`. + + Attributes: Same as torch.ao.nn.quantized.Conv3d + + """ + + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d # type: ignore[assignment] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + assert padding_mode != "reflect", "Conv3d does not support reflection padding" + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return torch.ops.quantized.conv3d_relu( + input, self._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedConvReLU3d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d: + assert mod.bn.running_var is not None and mod.bn.running_mean is not None + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d, ( + "BatchNorm3d should be fused into Conv3d before converting to reference module" + ) + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..0ff5a7e4029fa58b9ee476fef934ea3bab8ea689 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py @@ -0,0 +1,190 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.quantized as nnq +from torch.ao.nn.quantized.modules.utils import _quantize_weight + + +__all__ = [ + "LinearReLU", + "LinearLeakyReLU", + "LinearTanh", +] + + +class LinearReLU(nnq.Linear): + r""" + A LinearReLU module fused from Linear and ReLU modules + + We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`. + + Attributes: + Same as torch.ao.nn.quantized.Linear + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.LinearReLU(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear_relu( + x, self._packed_params._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedLinearReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) + + @classmethod + def from_reference(cls, ref_linear_relu, output_scale, output_zero_point): + return super().from_reference( + ref_linear_relu[0], output_scale, output_zero_point + ) + + +class LinearLeakyReLU(nnq.Linear): + r""" + For onednn backend only + A LinearLeakyReLU module fused from Linear and LeakyReLU modules + We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`. + Attributes: + Same as torch.ao.nn.quantized.Linear + + negative_slope + Examples:: + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _FLOAT_MODULE = nni.LinearLeakyReLU # type: ignore[assignment] + + def __init__( + self, in_features, out_features, negative_slope, bias=True, dtype=torch.qint8 + ): + super().__init__(in_features, out_features, bias, dtype) + self.negative_slope = negative_slope + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear_leaky_relu( + x, + self._packed_params._packed_params, + self.scale, + self.zero_point, + self.negative_slope, + ) + + def _get_name(self): + return "QuantizedLinearLeakyReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) == nni.LinearLeakyReLU, ( + "Input float module should be LinearLeakyReLU" + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + activation_post_process = mod.activation_post_process + leaky_relu = mod[1] + mod = mod[0] + weight_post_process = mod.qconfig.weight() # type: ignore[union-attr, operator] + weight_post_process(mod.weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator] + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + qlinear_leaky_relu = cls( + mod.in_features, mod.out_features, leaky_relu.negative_slope, dtype=dtype + ) + qlinear_leaky_relu.set_weight_bias(qweight, mod.bias) # type: ignore[arg-type] + qlinear_leaky_relu.scale = float(act_scale) + qlinear_leaky_relu.zero_point = int(act_zp) + return qlinear_leaky_relu + + @classmethod + def from_reference(cls, ref_mod, output_scale, output_zero_point): + linear = ref_mod[0] + leaky_relu = ref_mod[1] + qlinear_leaky_relu = cls( + linear.in_features, linear.out_features, leaky_relu.negative_slope + ) + qweight = linear.get_quantized_weight() + qlinear_leaky_relu.set_weight_bias(qweight, linear.bias) + qlinear_leaky_relu.scale = float(output_scale) + qlinear_leaky_relu.zero_point = int(output_zero_point) + return qlinear_leaky_relu + + +class LinearTanh(nnq.Linear): + r""" + A LinearTanh module fused from Linear and Tanh modules + + We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`. + + Attributes: + Same as torch.ao.nn.quantized.Linear + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.LinearTanh(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _FLOAT_MODULE = nni.LinearTanh # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear_tanh( + x, self._packed_params._packed_params, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedLinearTanh" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) == nni.LinearTanh, "Input float module should be LinearTanh" + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + activation_post_process = mod.activation_post_process + mod = mod[0] + weight_post_process = mod.qconfig.weight() # type: ignore[union-attr,operator] + weight_post_process(mod.weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator] + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + qlinear_tanh = cls(mod.in_features, mod.out_features, dtype=dtype) + qlinear_tanh.set_weight_bias(qweight, mod.bias) # type: ignore[arg-type] + qlinear_tanh.scale = float(act_scale) + qlinear_tanh.zero_point = int(act_zp) + return qlinear_tanh + + @classmethod + def from_reference(cls, ref_mod, output_scale, output_zero_point): + linear = ref_mod[0] + qlinear_tanh = cls(linear.in_features, linear.out_features) + qweight = linear.get_quantized_weight() + qlinear_tanh.set_weight_bias(qweight, linear.bias) + qlinear_tanh.scale = float(output_scale) + qlinear_tanh.zero_point = int(output_zero_point) + return qlinear_tanh diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b65f1f2cf73608d292edb174d9602c99638ec773 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..028939f6e4c163fbd6545a294501e92abebea6de Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dca71fcf09b019f3e197576eb415ba4fd54fa28a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__init__.py @@ -0,0 +1,4 @@ +from .linear import Linear + + +__all__ = ["Linear"] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8a65e40419d2dfca9de01484608395a56c8c799 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..c8e30b26fb52fcb3bb17c2420c4e99e501aef8ef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py @@ -0,0 +1,40 @@ +from typing import Optional, TYPE_CHECKING, Union + +import torch + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import QConfig # noqa: TC004 + + +__all__ = ["Linear"] + + +class Linear(torch.ao.nn.qat.Linear): + r""" + A linear module attached with FakeQuantize modules for weight, + used for dynamic quantization aware training. + + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear + for documentation. + + Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to + default. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + qconfig: Optional["QConfig"] = None, + device: Optional[Union[int, str, torch.device]] = None, + dtype: Optional[str] = None, + ) -> None: + super().__init__(in_features, out_features, bias, qconfig, device, dtype) + if not torch.ao.quantization.qconfig._activation_is_memoryless(qconfig): # type: ignore[arg-type] + raise ValueError( + "Dynamic QAT requires a memoryless observer." + + "This means a MovingAverage observer with averaging constant equal to 1" + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e28e0968a60d7612ebbd26d5f607b4407c2d380 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__init__.py @@ -0,0 +1,13 @@ +from .conv import Conv1d, Conv2d, Conv3d +from .embedding_ops import Embedding, EmbeddingBag +from .linear import Linear + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "Embedding", + "EmbeddingBag", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a921a51e9ec4bedec0a308bbecd8c7985c09634d Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a214795ca36faec8bef93c2cdada9430dc6a0ab4 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28bb6a194f9af4afca502a08e877b96e80623145 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..498a811565a188295bc865d2a30e8dbfd1edafe6 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/conv.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..90474ab1ce60cb3bd8600af80599865f94b5e23d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/conv.py @@ -0,0 +1,311 @@ +# mypy: allow-untyped-defs +from typing import ClassVar, Union + +import torch +import torch.nn as nn +from torch.ao.nn.intrinsic import _FusedModule +from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t +from torch.nn.modules.utils import _pair, _single, _triple + + +__all__ = ["Conv1d", "Conv2d", "Conv3d"] + + +class _ConvNd(nn.modules.conv._ConvNd): + _FLOAT_MODULE: ClassVar[type[nn.modules.conv._ConvNd]] + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: tuple[int, ...], + stride: tuple[int, ...], + padding: Union[str, tuple[int, ...]], + dilation: tuple[int, ...], + transposed: bool, + output_padding: tuple[int, ...], + groups: int, + bias: bool, + padding_mode: str, + qconfig=None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + nn.modules.conv._ConvNd.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs) + + def forward(self, input): + return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + + @staticmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module + + Args: + `mod`: a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + assert type(mod) == cls._FLOAT_MODULE, ( + "qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + if issubclass(type(mod), _FusedModule): + mod = mod[0] + qconfig = mod.qconfig + qat_conv = cls( + mod.in_channels, + mod.out_channels, + mod.kernel_size, + stride=mod.stride, + padding=mod.padding, + dilation=mod.dilation, + groups=mod.groups, + bias=mod.bias is not None, + padding_mode=mod.padding_mode, + qconfig=qconfig, + ) + qat_conv.weight = mod.weight + qat_conv.bias = mod.bias + return qat_conv + + def to_float(self): + """This works for both single qat conv, and the qat conv - relu modules + to convert the qat module to a floating point module + """ + cls = type(self) + conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined] + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.groups, + self.bias is not None, + self.padding_mode, + ) + conv.weight = torch.nn.Parameter(self.weight.detach()) + if self.bias is not None: + conv.bias = torch.nn.Parameter(self.bias.detach()) + # conv relu + if issubclass(cls, _FusedModule): + modules = [conv] + assert hasattr(cls, "_FLOAT_RELU_MODULE") + relu = cls._FLOAT_RELU_MODULE() + modules.append(relu) + fused = cls._FLOAT_MODULE(*modules) + fused.train(self.training) + return fused + else: + return conv + + +class Conv1d(_ConvNd, nn.Conv1d): + r""" + A Conv1d module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as :class:`~torch.nn.Conv1d` + + Similar to :class:`~torch.nn.Conv2d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: Union[str, _size_1_t] = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + qconfig=None, + device=None, + dtype=None, + ) -> None: + kernel_size_ = _single(kernel_size) + stride_ = _single(stride) + padding_ = padding if isinstance(padding, str) else _single(padding) + dilation_ = _single(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride=stride_, + padding=padding_, + dilation=dilation_, + transposed=False, + output_padding=_single(0), + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + device=device, + dtype=dtype, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class Conv2d(_ConvNd, nn.Conv2d): + r""" + A Conv2d module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.Conv2d`, please see + https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d + for documentation. + + Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + qconfig=None, + device=None, + dtype=None, + ) -> None: + kernel_size_ = _pair(kernel_size) + stride_ = _pair(stride) + padding_ = padding if isinstance(padding, str) else _pair(padding) + dilation_ = _pair(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride=stride_, + padding=padding_, + dilation=dilation_, + transposed=False, + output_padding=_pair(0), + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + device=device, + dtype=dtype, + ) + + def forward(self, input): + return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class Conv3d(_ConvNd, nn.Conv3d): + r""" + A Conv3d module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.Conv3d`, please see + https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d + for documentation. + + Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to + default. + + Attributes: + weight_fake_quant: fake quant module for weight + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: Union[str, _size_3_t] = 0, + dilation: _size_3_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + qconfig=None, + device=None, + dtype=None, + ) -> None: + kernel_size_ = _triple(kernel_size) + stride_ = _triple(stride) + padding_ = padding if isinstance(padding, str) else _triple(padding) + dilation_ = _triple(dilation) + super().__init__( + in_channels, + out_channels, + kernel_size_, + stride=stride_, + padding=padding_, + dilation=dilation_, + transposed=False, + output_padding=_triple(0), + groups=groups, + bias=bias, + padding_mode=padding_mode, + qconfig=qconfig, + device=device, + dtype=dtype, + ) + + def forward(self, input): + return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return super().from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/embedding_ops.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..13fd7a5983fbee566ed449c54b325100ee1fbc13 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/embedding_ops.py @@ -0,0 +1,250 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +__all__ = ["Embedding", "EmbeddingBag"] + + +class Embedding(nn.Embedding): + r""" + An embedding bag module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.Embedding`, please see + https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding + for documentation. + + Similar to `torch.nn.Embedding`, with FakeQuantize modules initialized to + default. + + Attributes: + weight: fake quant module for weight + """ + + _FLOAT_MODULE = nn.Embedding + + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, + qconfig=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + **factory_kwargs, + ) + assert qconfig, "qconfig must be provided for QAT module" + assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, ( + "Embedding weights requires a qscheme of torch.per_channel_affine_float_qparams Got " + + str(qconfig.weight().qscheme) + ) + self.qconfig = qconfig + self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs) + + def forward(self, input) -> Tensor: + return F.embedding( + input, + self.weight_fake_quant(self.weight), + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module + + Args: `mod` a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + assert type(mod) == cls._FLOAT_MODULE, ( + " qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + weight_qscheme = mod.qconfig.weight().qscheme # type: ignore[union-attr, operator] + assert weight_qscheme == torch.per_channel_affine_float_qparams, ( + "Embedding weights requires a qscheme of torch.per_channel_affine_float_qparams Got " + + str(weight_qscheme) + ) + + qconfig = mod.qconfig + qat_embedding_bag = cls( + mod.num_embeddings, + mod.embedding_dim, + mod.padding_idx, + mod.max_norm, + mod.norm_type, + mod.scale_grad_by_freq, + mod.sparse, + mod.weight, + qconfig=qconfig, + ) + + return qat_embedding_bag + + def to_float(self): + embedding_bag = torch.nn.Embedding( + self.num_embeddings, + self.embedding_dim, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + None, + ) + embedding_bag.weight = torch.nn.Parameter(self.weight.detach()) + embedding_bag.train(self.training) + return embedding_bag + + +class EmbeddingBag(nn.EmbeddingBag): + r""" + An embedding bag module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.EmbeddingBag`, please see + https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag + for documentation. + + Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to + default. + + Attributes: + weight: fake quant module for weight + """ + + _FLOAT_MODULE = nn.EmbeddingBag + + def __init__( + self, + num_embeddings, + embedding_dim, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + mode="mean", + sparse=False, + _weight=None, + include_last_offset=False, + padding_idx=None, + qconfig=None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_embeddings, + embedding_dim, + max_norm, + norm_type, + scale_grad_by_freq, + mode, + sparse, + _weight, + include_last_offset, + padding_idx, + **factory_kwargs, + ) + assert qconfig, "qconfig must be provided for QAT module" + assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, ( + "Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got " + + str(qconfig.weight().qscheme) + ) + self.qconfig = qconfig + self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs) + + def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor: + return F.embedding_bag( + input, + self.weight_fake_quant(self.weight), + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module + + Args: `mod` a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + assert type(mod) == cls._FLOAT_MODULE, ( + " qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + weight_qscheme = mod.qconfig.weight().qscheme # type: ignore[union-attr, operator] + assert weight_qscheme == torch.per_channel_affine_float_qparams, ( + "Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got " + + str(weight_qscheme) + ) + + qconfig = mod.qconfig + qat_embedding_bag = cls( + mod.num_embeddings, + mod.embedding_dim, + mod.max_norm, + mod.norm_type, + mod.scale_grad_by_freq, + mod.mode, + mod.sparse, + mod.weight, + mod.include_last_offset, + mod.padding_idx, + qconfig=qconfig, + ) + + return qat_embedding_bag + + def to_float(self): + embedding_bag = torch.nn.EmbeddingBag( + self.num_embeddings, + self.embedding_dim, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + None, + self.include_last_offset, + self.padding_idx, + ) + embedding_bag.weight = torch.nn.Parameter(self.weight.detach()) + embedding_bag.train(self.training) + return embedding_bag diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/linear.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..5edf16ed3ea53d0323eda248b95703d5245b1786 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/qat/modules/linear.py @@ -0,0 +1,97 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.ao.nn.intrinsic import LinearReLU +from torch.nn.utils.parametrize import ( + is_parametrized, + transfer_parametrizations_and_params, + type_before_parametrizations, +) + + +__all__ = ["Linear"] + + +class Linear(nn.Linear): + r""" + A linear module attached with FakeQuantize modules for weight, + used for quantization aware training. + + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear + for documentation. + + Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to + default. + + Attributes: + weight: fake quant module for weight + """ + + _FLOAT_MODULE = nn.Linear + + def __init__( + self, + in_features, + out_features, + bias=True, + qconfig=None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(in_features, out_features, bias, **factory_kwargs) + assert qconfig, "qconfig must be provided for QAT module" + self.qconfig = qconfig + self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs) + + def forward(self, input): + return F.linear(input, self.weight_fake_quant(self.weight), self.bias) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a qat module from a float module or qparams_dict + Args: `mod` a float module, either produced by torch.ao.quantization utilities + or directly from user + """ + assert type_before_parametrizations(mod) == cls._FLOAT_MODULE, ( + " qat." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + assert mod.qconfig, "Input float module must have a valid qconfig" + if type_before_parametrizations(mod) == LinearReLU: + mod = mod[0] + + qconfig = mod.qconfig + qat_linear = cls( + mod.in_features, + mod.out_features, + bias=mod.bias is not None, + qconfig=qconfig, + ) + + if is_parametrized(mod, "weight"): + transfer_parametrizations_and_params(mod, qat_linear, "weight") + else: + qat_linear.weight = mod.weight + + if is_parametrized(mod, "bias"): + transfer_parametrizations_and_params(mod, qat_linear, "bias") + else: + qat_linear.bias = mod.bias + + return qat_linear + + def to_float(self): + linear = torch.nn.Linear( + self.in_features, self.out_features, self.bias is not None + ) + linear.weight = torch.nn.Parameter(self.weight.detach()) + if self.bias is not None: + linear.bias = torch.nn.Parameter(self.bias.detach()) + linear.train(self.training) + return linear diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2d0f67fb0c35d23a13c1f09e2a0d8af60654168 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..221107660158171ada5d1823cc193666c9e152e7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__init__.py @@ -0,0 +1,9 @@ +from .activation import MultiheadAttention +from .rnn import LSTM, LSTMCell + + +__all__ = [ + "LSTM", + "LSTMCell", + "MultiheadAttention", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3db7a9a3db7593cd4cc62876bd855360f6840e0f Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/activation.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/activation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fe74188f0f2b055f91ea1e99de746ffbe4f53a3 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/activation.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/rnn.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ac6eb0ae3f818f2a163a3342f41eb5c4860a8f1 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/__pycache__/rnn.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/activation.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..7f6da0815116b991e827f58340d569b12f101637 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/activation.py @@ -0,0 +1,552 @@ +# mypy: allow-untyped-defs +import warnings +from typing import Optional + +import torch +import torch.jit # this is needed to avoid a circular import +import torch.nn.functional as F +from torch import nn, Tensor + + +__all__ = ["MultiheadAttention"] + + +class MultiheadAttention(nn.MultiheadAttention): + _FLOAT_MODULE = nn.MultiheadAttention + + r"""Quantizable implementation of the MultiheadAttention. + + Note:: + Please, refer to :class:`~torch.nn.MultiheadAttention` for more + information + + Allows the model to jointly attend to information from different + representation subspaces. + See reference: Attention Is All You Need + + The original MHA module is not quantizable. + This reimplements it by explicitly instantiating the linear layers. + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in value. Default: None. + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set + to :attr:`embed_dim` such that query, key, and value have the same + number of features. + + Examples:: + + >>> import torch.ao.nn.quantizable as nnqa + >>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + Note:: + Please, follow the quantization flow to convert the quantizable MHA. + """ + __constants__ = ["batch_first"] + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + add_bias_kv: bool = False, + add_zero_attn: bool = False, + kdim: Optional[int] = None, + vdim: Optional[int] = None, + batch_first: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + embed_dim, + num_heads, + dropout, + bias, + add_bias_kv, + add_zero_attn, + kdim, + vdim, + batch_first, + **factory_kwargs, + ) + self.linear_Q = nn.Linear( + self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs + ) + self.linear_K = nn.Linear( + self.kdim, self.embed_dim, bias=bias, **factory_kwargs + ) + self.linear_V = nn.Linear( + self.vdim, self.embed_dim, bias=bias, **factory_kwargs + ) + # for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969 + self.out_proj = nn.Linear( + self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs + ) # type: ignore[assignment] + + # Functionals + self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional() + # note: importing torch.ao.nn.quantized at top creates a circular import + + # Quant/Dequant + self.quant_attn_output = torch.ao.quantization.QuantStub() + self.quant_attn_output_weights = torch.ao.quantization.QuantStub() + self.dequant_q = torch.ao.quantization.DeQuantStub() + self.dequant_k = torch.ao.quantization.DeQuantStub() + self.dequant_v = torch.ao.quantization.DeQuantStub() + + def _get_name(self): + return "QuantizableMultiheadAttention" + + @classmethod + def from_float(cls, other): + assert type(other) == cls._FLOAT_MODULE + assert hasattr(other, "qconfig"), "The float module must have 'qconfig'" + # Setting the dropout to 0.0! + observed = cls( + other.embed_dim, + other.num_heads, + other.dropout, + (other.in_proj_bias is not None), + (other.bias_k is not None), + other.add_zero_attn, + other.kdim, + other.vdim, + other.batch_first, + ) + observed.bias_k = other.bias_k + observed.bias_v = other.bias_v + observed.qconfig = other.qconfig + + # Set the linear weights + # for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969 + observed.out_proj.weight = other.out_proj.weight + observed.out_proj.bias = other.out_proj.bias + if other._qkv_same_embed_dim: + # Use separate params + bias = other.in_proj_bias + _start = 0 + _end = _start + other.embed_dim + weight = other.in_proj_weight[_start:_end, :] + if bias is not None: + bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad) + observed.linear_Q.weight = torch.nn.Parameter(weight, weight.requires_grad) + observed.linear_Q.bias = bias + + bias = other.in_proj_bias + _start = _end + _end = _start + other.embed_dim + weight = other.in_proj_weight[_start:_end, :] + if bias is not None: + bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad) + observed.linear_K.weight = torch.nn.Parameter(weight, weight.requires_grad) + observed.linear_K.bias = bias + + bias = other.in_proj_bias + _start = _end + weight = other.in_proj_weight[_start:, :] + if bias is not None: + bias = torch.nn.Parameter(bias[_start:], bias.requires_grad) + observed.linear_V.weight = torch.nn.Parameter(weight, weight.requires_grad) + observed.linear_V.bias = bias + else: + observed.linear_Q.weight = nn.Parameter(other.q_proj_weight) + observed.linear_K.weight = nn.Parameter(other.k_proj_weight) + observed.linear_V.weight = nn.Parameter(other.v_proj_weight) + if other.in_proj_bias is None: + observed.linear_Q.bias = None + observed.linear_K.bias = None + observed.linear_V.bias = None + else: + observed.linear_Q.bias = nn.Parameter( + other.in_proj_bias[0 : other.embed_dim] + ) + observed.linear_K.bias = nn.Parameter( + other.in_proj_bias[other.embed_dim : (other.embed_dim * 2)] + ) + observed.linear_V.bias = nn.Parameter( + other.in_proj_bias[(other.embed_dim * 2) :] + ) + observed.eval() + # Explicit prepare + observed = torch.ao.quantization.prepare(observed, inplace=True) + return observed + + @torch.jit.unused + def dequantize(self): + r"""Utility to convert the quantized MHA back to float. + + The motivation for this is that it is not trivial to convert the weights + from the format that is used in the quantized version back to the + float. + """ + fp = self._FLOAT_MODULE( + self.embed_dim, + self.num_heads, + self.dropout, + (self.linear_Q._weight_bias()[1] is not None), # type: ignore[operator] + (self.bias_k is not None), + self.add_zero_attn, + self.kdim, + self.vdim, + self.batch_first, + ) + assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim + if self.bias_k is not None: + fp.bias_k = nn.Parameter(self.bias_k.dequantize()) + if self.bias_v is not None: + fp.bias_v = nn.Parameter(self.bias_v.dequantize()) + + # Set the linear weights + # Note: Because the linear layers are quantized, mypy does not nkow how + # to deal with them -- might need to ignore the typing checks. + # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969 + w, b = self.out_proj._weight_bias() # type: ignore[operator, has-type] + fp.out_proj.weight = nn.Parameter(w.dequantize()) + if b is not None: + fp.out_proj.bias = nn.Parameter(b) + + wQ, bQ = self.linear_Q._weight_bias() # type: ignore[operator] + wQ = wQ.dequantize() + wK, bK = self.linear_K._weight_bias() # type: ignore[operator] + wK = wK.dequantize() + wV, bV = self.linear_V._weight_bias() # type: ignore[operator] + wV = wV.dequantize() + if fp._qkv_same_embed_dim: + # Use separate params + _start = 0 + _end = _start + fp.embed_dim + fp.in_proj_weight[_start:_end, :] = wQ + if fp.in_proj_bias is not None: + assert all(bQ == 0) + fp.in_proj_bias[_start:_end] = bQ + + _start = _end + _end = _start + fp.embed_dim + fp.in_proj_weight[_start:_end, :] = wK + if fp.in_proj_bias is not None: + assert all(bK == 0) + fp.in_proj_bias[_start:_end] = bK + + _start = _end + fp.in_proj_weight[_start:, :] = wV + if fp.in_proj_bias is not None: + assert all(bV == 0) + fp.in_proj_bias[_start:] = bV + else: + fp.q_proj_weight = nn.Parameter(wQ) + fp.k_proj_weight = nn.Parameter(wK) + fp.v_proj_weight = nn.Parameter(wV) + if fp.in_proj_bias is None: + self.linear_Q.bias = None + self.linear_K.bias = None + self.linear_V.bias = None + else: + fp.in_proj_bias[0 : fp.embed_dim] = bQ + fp.in_proj_bias[fp.embed_dim : (fp.embed_dim * 2)] = bK + fp.in_proj_bias[(fp.embed_dim * 2) :] = bV + + return fp + + @classmethod + def from_observed(cls, other): + # The whole flow is float -> observed -> quantized + # This class does float -> observed only + # See nn.quantized.MultiheadAttention + raise NotImplementedError( + "It looks like you are trying to prepare an " + "MHA module. Please, see " + "the examples on quantizable MHAs." + ) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, + ) -> tuple[Tensor, Optional[Tensor]]: + r""" + Note:: + Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more + information + + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask. + Default: ``False``. + - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True.``. Default: True (i.e. average weights across heads) + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``. + - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged + across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length, + S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(N, num_heads, L, S)`. + """ + return self._forward_impl( + query, + key, + value, + key_padding_mask, + need_weights, + attn_mask, + average_attn_weights, + is_causal, + ) + + def _forward_impl( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, + ) -> tuple[Tensor, Optional[Tensor]]: + # This version will not deal with the static key/value pairs. + # Keeping it here for future changes. + # + # TODO: This method has some duplicate lines with the + # `torch.nn.functional.multi_head_attention`. Will need to refactor. + static_k = None + static_v = None + + if attn_mask is not None and is_causal: + raise AssertionError("Only allow causal mask or attn_mask") + + if is_causal: + raise AssertionError("causal mask not supported by AO MHA module") + + if self.batch_first: + query, key, value = (x.transpose(0, 1) for x in (query, key, value)) + + tgt_len, bsz, embed_dim_to_check = query.size() + assert self.embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = self.embed_dim // self.num_heads + assert head_dim * self.num_heads == self.embed_dim, ( + "embed_dim must be divisible by num_heads" + ) + scaling = float(head_dim) ** -0.5 + + q = self.linear_Q(query) + k = self.linear_K(key) + v = self.linear_V(value) + + q = self.q_scaling_product.mul_scalar(q, scaling) + + if attn_mask is not None: + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for `attn_mask` in `nn.MultiheadAttention` is deprecated. " + "Use bool tensor instead.", + stacklevel=3, + ) + attn_mask = attn_mask.to(torch.bool) + assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, ( + f"Only float and bool types are supported for attn_mask, not {attn_mask.dtype}" + ) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * self.num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported" + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for `key_padding_mask` in `nn.MultiheadAttention` is deprecated. " + "Use bool tensor instead.", + stacklevel=3, + ) + key_padding_mask = key_padding_mask.to(torch.bool) + if self.bias_k is not None and self.bias_v is not None: + if static_k is None and static_v is None: + # Explicitly assert that bias_k and bias_v are not None + # in a way that TorchScript can understand. + bias_k = self.bias_k + assert bias_k is not None + bias_v = self.bias_v + assert bias_v is not None + + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + else: + assert self.bias_k is None + assert self.bias_v is None + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * self.num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * self.num_heads + assert static_v.size(2) == head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + src_len += 1 + k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:]) + if k.is_quantized: + k_zeros = torch.quantize_per_tensor( + k_zeros, k.q_scale(), k.q_zero_point(), k.dtype + ) + k = torch.cat([k, k_zeros], dim=1) + v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:]) + if v.is_quantized: + v_zeros = torch.quantize_per_tensor( + v_zeros, v.q_scale(), v.q_zero_point(), v.dtype + ) + v = torch.cat([v, v_zeros], dim=1) + + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + + # Leaving the quantized zone here + q = self.dequant_q(q) + k = self.dequant_k(k) + v = self.dequant_v(v) + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [ + bsz * self.num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * self.num_heads, tgt_len, src_len + ) + + attn_output_weights = F.softmax(attn_output_weights, dim=-1) + attn_output_weights = F.dropout( + attn_output_weights, p=self.dropout, training=self.training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim] + if self.batch_first: + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) + else: + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, self.embed_dim) + ) + + # Reentering the quantized zone + attn_output = self.quant_attn_output(attn_output) + # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969 + attn_output = self.out_proj(attn_output) # type: ignore[has-type] + attn_output_weights = self.quant_attn_output_weights(attn_output_weights) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + if average_attn_weights: + attn_output_weights = attn_output_weights.mean(dim=1) + return attn_output, attn_output_weights + else: + return attn_output, None diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/rnn.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..ad32cf174c6280149ff967fc54cd8e439d7c29f2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantizable/modules/rnn.py @@ -0,0 +1,599 @@ +""" +We will recreate all the RNN modules as we require the modules to be decomposed +into its building blocks to be able to observe. +""" + +# mypy: allow-untyped-defs + +import numbers +import warnings +from typing import Optional + +import torch +from torch import Tensor + + +__all__ = ["LSTMCell", "LSTM"] + + +class LSTMCell(torch.nn.Module): + r"""A quantizable long short-term memory (LSTM) cell. + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell` + + `split_gates`: specify True to compute the input/forget/cell/output gates separately + to avoid an intermediate tensor which is subsequently chunk'd. This optimization can + be beneficial for on-device inference latency. This flag is cascaded down from the + parent classes. + + Examples:: + + >>> import torch.ao.nn.quantizable as nnqa + >>> rnn = nnqa.LSTMCell(10, 20) + >>> input = torch.randn(6, 10) + >>> hx = torch.randn(3, 20) + >>> cx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx, cx = rnn(input[i], (hx, cx)) + ... output.append(hx) + """ + + _FLOAT_MODULE = torch.nn.LSTMCell + __constants__ = ["split_gates"] # for jit.script + + def __init__( + self, + input_dim: int, + hidden_dim: int, + bias: bool = True, + device=None, + dtype=None, + *, + split_gates=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.input_size = input_dim + self.hidden_size = hidden_dim + self.bias = bias + self.split_gates = split_gates + + if not split_gates: + self.igates: torch.nn.Module = torch.nn.Linear( + input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs + ) + self.hgates: torch.nn.Module = torch.nn.Linear( + hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs + ) + self.gates: torch.nn.Module = torch.ao.nn.quantized.FloatFunctional() + else: + # keep separate Linear layers for each gate + self.igates = torch.nn.ModuleDict() + self.hgates = torch.nn.ModuleDict() + self.gates = torch.nn.ModuleDict() + for g in ["input", "forget", "cell", "output"]: + # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` + self.igates[g] = torch.nn.Linear( + input_dim, hidden_dim, bias=bias, **factory_kwargs + ) + # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` + self.hgates[g] = torch.nn.Linear( + hidden_dim, hidden_dim, bias=bias, **factory_kwargs + ) + # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` + self.gates[g] = torch.ao.nn.quantized.FloatFunctional() + + self.input_gate = torch.nn.Sigmoid() + self.forget_gate = torch.nn.Sigmoid() + self.cell_gate = torch.nn.Tanh() + self.output_gate = torch.nn.Sigmoid() + + self.fgate_cx = torch.ao.nn.quantized.FloatFunctional() + self.igate_cgate = torch.ao.nn.quantized.FloatFunctional() + self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional() + + self.ogate_cy = torch.ao.nn.quantized.FloatFunctional() + + self.initial_hidden_state_qparams: tuple[float, int] = (1.0, 0) + self.initial_cell_state_qparams: tuple[float, int] = (1.0, 0) + self.hidden_state_dtype: torch.dtype = torch.quint8 + self.cell_state_dtype: torch.dtype = torch.quint8 + + def forward( + self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[Tensor, Tensor]: + if hidden is None or hidden[0] is None or hidden[1] is None: + hidden = self.initialize_hidden(x.shape[0], x.is_quantized) + hx, cx = hidden + + if not self.split_gates: + igates = self.igates(x) + hgates = self.hgates(hx) + gates = self.gates.add(igates, hgates) # type: ignore[operator] + + input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1) + + input_gate = self.input_gate(input_gate) + forget_gate = self.forget_gate(forget_gate) + cell_gate = self.cell_gate(cell_gate) + out_gate = self.output_gate(out_gate) + else: + # apply each input + hidden projection and add together + gate = {} + for (key, gates), igates, hgates in zip( + self.gates.items(), # type: ignore[operator] + self.igates.values(), # type: ignore[operator] + self.hgates.values(), # type: ignore[operator] + ): + gate[key] = gates.add(igates(x), hgates(hx)) + + input_gate = self.input_gate(gate["input"]) + forget_gate = self.forget_gate(gate["forget"]) + cell_gate = self.cell_gate(gate["cell"]) + out_gate = self.output_gate(gate["output"]) + + fgate_cx = self.fgate_cx.mul(forget_gate, cx) + igate_cgate = self.igate_cgate.mul(input_gate, cell_gate) + fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate) + cy = fgate_cx_igate_cgate + + # TODO: make this tanh a member of the module so its qparams can be configured + tanh_cy = torch.tanh(cy) + hy = self.ogate_cy.mul(out_gate, tanh_cy) + return hy, cy + + def initialize_hidden( + self, batch_size: int, is_quantized: bool = False + ) -> tuple[Tensor, Tensor]: + h, c = ( + torch.zeros((batch_size, self.hidden_size)), + torch.zeros((batch_size, self.hidden_size)), + ) + if is_quantized: + (h_scale, h_zp) = self.initial_hidden_state_qparams + (c_scale, c_zp) = self.initial_cell_state_qparams + h = torch.quantize_per_tensor( + h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype + ) + c = torch.quantize_per_tensor( + c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype + ) + return h, c + + def _get_name(self): + return "QuantizableLSTMCell" + + @classmethod + def from_params(cls, wi, wh, bi=None, bh=None, split_gates=False): + """Uses the weights and biases to create a new LSTM cell. + + Args: + wi, wh: Weights for the input and hidden layers + bi, bh: Biases for the input and hidden layers + """ + assert (bi is None) == (bh is None) # Either both None or both have values + input_size = wi.shape[1] + hidden_size = wh.shape[1] + cell = cls( + input_dim=input_size, + hidden_dim=hidden_size, + bias=(bi is not None), + split_gates=split_gates, + ) + + if not split_gates: + cell.igates.weight = torch.nn.Parameter(wi) + if bi is not None: + cell.igates.bias = torch.nn.Parameter(bi) + cell.hgates.weight = torch.nn.Parameter(wh) + if bh is not None: + cell.hgates.bias = torch.nn.Parameter(bh) + else: + # split weight/bias + for w, b, gates in zip([wi, wh], [bi, bh], [cell.igates, cell.hgates]): + for w_chunk, gate in zip(w.chunk(4, dim=0), gates.values()): # type: ignore[operator] + gate.weight = torch.nn.Parameter(w_chunk) + + if b is not None: + for b_chunk, gate in zip(b.chunk(4, dim=0), gates.values()): # type: ignore[operator] + gate.bias = torch.nn.Parameter(b_chunk) + + return cell + + @classmethod + def from_float(cls, other, use_precomputed_fake_quant=False, split_gates=False): + assert type(other) == cls._FLOAT_MODULE + assert hasattr(other, "qconfig"), "The float module must have 'qconfig'" + observed = cls.from_params( + other.weight_ih, + other.weight_hh, + other.bias_ih, + other.bias_hh, + split_gates=split_gates, + ) + observed.qconfig = other.qconfig + observed.igates.qconfig = other.qconfig + observed.hgates.qconfig = other.qconfig + if split_gates: + # also apply qconfig directly to Linear modules + for g in observed.igates.values(): + g.qconfig = other.qconfig + for g in observed.hgates.values(): + g.qconfig = other.qconfig + return observed + + +class _LSTMSingleLayer(torch.nn.Module): + r"""A single one-directional LSTM layer. + + The difference between a layer and a cell is that the layer can process a + sequence, while the cell only expects an instantaneous value. + """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + bias: bool = True, + device=None, + dtype=None, + *, + split_gates=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.cell = LSTMCell( + input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs + ) + + def forward(self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None): + result = [] + seq_len = x.shape[0] + for i in range(seq_len): + hidden = self.cell(x[i], hidden) + result.append(hidden[0]) # type: ignore[index] + result_tensor = torch.stack(result, 0) + return result_tensor, hidden + + @classmethod + def from_params(cls, *args, **kwargs): + cell = LSTMCell.from_params(*args, **kwargs) + layer = cls( + cell.input_size, cell.hidden_size, cell.bias, split_gates=cell.split_gates + ) + layer.cell = cell + return layer + + +class _LSTMLayer(torch.nn.Module): + r"""A single bi-directional LSTM layer.""" + + def __init__( + self, + input_dim: int, + hidden_dim: int, + bias: bool = True, + batch_first: bool = False, + bidirectional: bool = False, + device=None, + dtype=None, + *, + split_gates=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.batch_first = batch_first + self.bidirectional = bidirectional + self.layer_fw = _LSTMSingleLayer( + input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs + ) + if self.bidirectional: + self.layer_bw = _LSTMSingleLayer( + input_dim, + hidden_dim, + bias=bias, + split_gates=split_gates, + **factory_kwargs, + ) + + def forward(self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None): + if self.batch_first: + x = x.transpose(0, 1) + if hidden is None: + hx_fw, cx_fw = (None, None) + else: + hx_fw, cx_fw = hidden + hidden_bw: Optional[tuple[Tensor, Tensor]] = None + if self.bidirectional: + if hx_fw is None: + hx_bw = None + else: + hx_bw = hx_fw[1] + hx_fw = hx_fw[0] + if cx_fw is None: + cx_bw = None + else: + cx_bw = cx_fw[1] + cx_fw = cx_fw[0] + if hx_bw is not None and cx_bw is not None: + hidden_bw = hx_bw, cx_bw + if hx_fw is None and cx_fw is None: + hidden_fw = None + else: + hidden_fw = ( + torch.jit._unwrap_optional(hx_fw), + torch.jit._unwrap_optional(cx_fw), + ) + result_fw, hidden_fw = self.layer_fw(x, hidden_fw) + + if hasattr(self, "layer_bw") and self.bidirectional: + x_reversed = x.flip(0) + result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw) + result_bw = result_bw.flip(0) + + result = torch.cat([result_fw, result_bw], result_fw.dim() - 1) + if hidden_fw is None and hidden_bw is None: + h = None + c = None + elif hidden_fw is None: + (h, c) = torch.jit._unwrap_optional(hidden_bw) + elif hidden_bw is None: + (h, c) = torch.jit._unwrap_optional(hidden_fw) + else: + h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore[list-item] + c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item] + else: + result = result_fw + h, c = torch.jit._unwrap_optional(hidden_fw) # type: ignore[assignment] + + if self.batch_first: + result.transpose_(0, 1) + + return result, (h, c) + + @classmethod + def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs): + r""" + There is no FP equivalent of this class. This function is here just to + mimic the behavior of the `prepare` within the `torch.ao.quantization` + flow. + """ + assert hasattr(other, "qconfig") or (qconfig is not None) + + input_size = kwargs.get("input_size", other.input_size) + hidden_size = kwargs.get("hidden_size", other.hidden_size) + bias = kwargs.get("bias", other.bias) + batch_first = kwargs.get("batch_first", other.batch_first) + bidirectional = kwargs.get("bidirectional", other.bidirectional) + split_gates = kwargs.get("split_gates", False) + + layer = cls( + input_size, + hidden_size, + bias, + batch_first, + bidirectional, + split_gates=split_gates, + ) + layer.qconfig = getattr(other, "qconfig", qconfig) + wi = getattr(other, f"weight_ih_l{layer_idx}") + wh = getattr(other, f"weight_hh_l{layer_idx}") + bi = getattr(other, f"bias_ih_l{layer_idx}", None) + bh = getattr(other, f"bias_hh_l{layer_idx}", None) + + layer.layer_fw = _LSTMSingleLayer.from_params( + wi, wh, bi, bh, split_gates=split_gates + ) + + if other.bidirectional: + wi = getattr(other, f"weight_ih_l{layer_idx}_reverse") + wh = getattr(other, f"weight_hh_l{layer_idx}_reverse") + bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None) + bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None) + layer.layer_bw = _LSTMSingleLayer.from_params( + wi, wh, bi, bh, split_gates=split_gates + ) + return layer + + +class LSTM(torch.nn.Module): + r"""A quantizable long short-term memory (LSTM). + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTM` + + Attributes: + layers : instances of the `_LSTMLayer` + + .. note:: + To access the weights and biases, you need to access them per layer. + See examples below. + + Examples:: + + >>> import torch.ao.nn.quantizable as nnqa + >>> rnn = nnqa.LSTM(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> c0 = torch.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + >>> # To get the weights: + >>> # xdoctest: +SKIP + >>> print(rnn.layers[0].weight_ih) + tensor([[...]]) + >>> print(rnn.layers[0].weight_hh) + AssertionError: There is no reverse path in the non-bidirectional layer + """ + + _FLOAT_MODULE = torch.nn.LSTM + + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + device=None, + dtype=None, + *, + split_gates: bool = False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.bidirectional = bidirectional + self.training = False # Default to eval mode. If we want to train, we will explicitly set to training. + + if ( + not isinstance(dropout, numbers.Number) + or not 0 <= dropout <= 1 + or isinstance(dropout, bool) + ): + raise ValueError( + "dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed" + ) + if dropout > 0: + warnings.warn( + "dropout option for quantizable LSTM is ignored. " + "If you are training, please, use nn.LSTM version " + "followed by `prepare` step." + ) + if num_layers == 1: + warnings.warn( + "dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + f"num_layers greater than 1, but got dropout={dropout} " + f"and num_layers={num_layers}" + ) + + layers = [ + _LSTMLayer( + self.input_size, + self.hidden_size, + self.bias, + batch_first=False, + bidirectional=self.bidirectional, + split_gates=split_gates, + **factory_kwargs, + ) + ] + layers.extend( + _LSTMLayer( + self.hidden_size, + self.hidden_size, + self.bias, + batch_first=False, + bidirectional=self.bidirectional, + split_gates=split_gates, + **factory_kwargs, + ) + for _ in range(1, num_layers) + ) + self.layers = torch.nn.ModuleList(layers) + + def forward(self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None): + if self.batch_first: + x = x.transpose(0, 1) + + max_batch_size = x.size(1) + num_directions = 2 if self.bidirectional else 1 + if hidden is None: + zeros = torch.zeros( + num_directions, + max_batch_size, + self.hidden_size, + dtype=torch.float, + device=x.device, + ) + zeros.squeeze_(0) + if x.is_quantized: + zeros = torch.quantize_per_tensor( + zeros, scale=1.0, zero_point=0, dtype=x.dtype + ) + hxcx = [(zeros, zeros) for _ in range(self.num_layers)] + else: + hidden_non_opt = torch.jit._unwrap_optional(hidden) + if isinstance(hidden_non_opt[0], Tensor): + hx = hidden_non_opt[0].reshape( + self.num_layers, num_directions, max_batch_size, self.hidden_size + ) + cx = hidden_non_opt[1].reshape( + self.num_layers, num_directions, max_batch_size, self.hidden_size + ) + hxcx = [ + (hx[idx].squeeze(0), cx[idx].squeeze(0)) + for idx in range(self.num_layers) + ] + else: + hxcx = hidden_non_opt + + hx_list = [] + cx_list = [] + for idx, layer in enumerate(self.layers): + x, (h, c) = layer(x, hxcx[idx]) + hx_list.append(torch.jit._unwrap_optional(h)) + cx_list.append(torch.jit._unwrap_optional(c)) + hx_tensor = torch.stack(hx_list) + cx_tensor = torch.stack(cx_list) + + # We are creating another dimension for bidirectional case + # need to collapse it + hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1]) + cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1]) + + if self.batch_first: + x = x.transpose(0, 1) + + return x, (hx_tensor, cx_tensor) + + def _get_name(self): + return "QuantizableLSTM" + + @classmethod + def from_float(cls, other, qconfig=None, split_gates=False): + assert isinstance(other, cls._FLOAT_MODULE) + assert hasattr(other, "qconfig") or qconfig + observed = cls( + other.input_size, + other.hidden_size, + other.num_layers, + other.bias, + other.batch_first, + other.dropout, + other.bidirectional, + split_gates=split_gates, + ) + observed.qconfig = getattr(other, "qconfig", qconfig) + for idx in range(other.num_layers): + observed.layers[idx] = _LSTMLayer.from_float( + other, idx, qconfig, batch_first=False, split_gates=split_gates + ) + + # Prepare the model + if other.training: + observed.train() + observed = torch.ao.quantization.prepare_qat(observed, inplace=True) + else: + observed.eval() + observed = torch.ao.quantization.prepare(observed, inplace=True) + return observed + + @classmethod + def from_observed(cls, other): + # The whole flow is float -> observed -> quantized + # This class does float -> observed only + raise NotImplementedError( + "It looks like you are trying to convert a " + "non-quantizable LSTM module. Please, see " + "the examples on quantizable LSTMs." + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77e97d8595282f3d69963ee129fa473249e3ae29 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/__init__.py @@ -0,0 +1,39 @@ +from . import functional +from .modules import * # noqa: F403 +from .modules import MaxPool2d + + +__all__ = [ + "BatchNorm2d", + "BatchNorm3d", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "DeQuantize", + "ELU", + "Embedding", + "EmbeddingBag", + "GroupNorm", + "Hardswish", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LayerNorm", + "LeakyReLU", + "Linear", + "LSTM", + "MultiheadAttention", + "Quantize", + "ReLU6", + "Sigmoid", + "Softmax", + "Dropout", + "PReLU", + # Wrapper modules + "FloatFunctional", + "FXFloatFunctional", + "QFunctional", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d36ee34e43fa1e3e12d920a4b22de747f82cebe0 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdfe34ad01bea39cba24d73884685ee0048dee5f Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47cc08eb00251604374adc42c561eabb92b5375a Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..969fd6f121f5ddb72ed2e8e158e3ee7e990cfd0c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__init__.py @@ -0,0 +1,26 @@ +from .conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from .linear import Linear +from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNNCell + + +__all__ = [ + "Linear", + "LSTM", + "GRU", + "LSTMCell", + "RNNCell", + "GRUCell", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3741a660c9d60ee8c506da285e8e8c614ba43eb7 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66190c99eaf721652dcd452039213f84aadcf6ec Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/linear.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..817326087a71942b34c37229a1e90ba8f7142621 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcaf0d910bb634bbe1cae2f6777bf4008b80fd6d Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/conv.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..8855ccfdbfe60d0c37a579799f1ca65be166cdfa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/conv.py @@ -0,0 +1,523 @@ +# mypy: allow-untyped-defs +r"""Dynamically quantized convolution modules.""" + +import warnings +from typing import ClassVar, Optional + +import torch +import torch.ao.nn.quantized as nnq +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch._ops import ops +from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding +from torch.nn.common_types import _size_1_t +from torch.nn.modules.utils import _pair, _single, _triple + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] + + +class Conv1d(nnq.Conv1d): + r"""A dynamically quantized conv module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv1d` and :class:`~torch.ao.nn.quantized.dynamic.Conv1d` and + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv1d` for other attributes. + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.quantized.dynamic.Conv1d(16, 33, 3, stride=2) + >>> input = torch.randn(20, 16, 100) + >>> output = m(input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + reduce_range=True, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950 + ) + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = padding if isinstance(padding, str) else _single(padding) + dilation = _single(dilation) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConv1d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != "zeros": + # Padding in Conv1d is stored as (p, p), need to get (p,) + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1]) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv1d_dynamic(input, self._packed_params, reduce_range) + + +class Conv2d(nnq.Conv2d): + r"""A dynamically quantized conv module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv2d` and :class:`~torch.ao.nn.quantized.dynamic.Conv2d` and + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv2d` for other attributes. + + Examples:: + + >>> # xdoctest: +SKIP + >>> # With square kernels and equal stride + >>> m = nn.quantized.dynamic.Conv2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) + >>> input = torch.randn(20, 16, 50, 100) + >>> output = m(input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module " + "has poor numerical accuracy and its use is not recommended" + ) + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConv2d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv2d_dynamic(input, self._packed_params, reduce_range) + + +class Conv3d(nnq.Conv3d): + r"""A dynamically quantized conv module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv3d` and :class:`~torch.ao.nn.quantized.dynamic.Conv3d` and + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv3d` for other attributes. + + Examples:: + + >>> # xdoctest: +SKIP + >>> # With square kernels and equal stride + >>> m = nn.quantized.dynamic.Conv3d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2)) + >>> input = torch.randn(20, 16, 56, 56, 56) + >>> output = m(input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950 + ) + assert padding_mode != "reflect", "Conv3d does not support reflection padding" + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _triple(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConv3d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv3d_dynamic(input, self._packed_params, reduce_range) + + +class ConvTranspose1d(nnq.ConvTranspose1d): + r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose1d`. + + For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv1d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose1d` for other attributes. + + Examples:: + + >>> # xdoctest: +SKIP + >>> # With square kernels and equal stride + >>> m = nndq.ConvTranspose1d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nndq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> downsample = nndq.Conv1d(16, 16, 3, stride=2, padding=1) + >>> upsample = nndq.ConvTranspose1d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose1d]] = nn.ConvTranspose1d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950 + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConvTranspose1d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + return torch.ops.quantized.conv_transpose1d_dynamic( + input, self._packed_params, reduce_range + ) + + +class ConvTranspose2d(nnq.ConvTranspose2d): + r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose2d`. + + For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv2d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose2d` for other attributes. + + Examples:: + + >>> # xdoctest: +SKIP + >>> # With square kernels and equal stride + >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose2d]] = nn.ConvTranspose2d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950 + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConvTranspose2d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + return ops.quantized.conv_transpose2d_dynamic( + input, self._packed_params, reduce_range + ) + + +class ConvTranspose3d(nnq.ConvTranspose3d): + r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose3d`. + + For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv3d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose3d` for other attributes. + + Examples:: + + >>> # xdoctest: +SKIP + >>> # With cubic kernels and equal stride + >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2) + >>> # non-cubic kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2)) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6, 6, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose3d]] = nn.ConvTranspose3d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + warnings.warn( + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950 + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "DynamicQuantizedConvTranspose3d" + + def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor: + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, T, H, W)`!") + return ops.quantized.conv_transpose3d_dynamic( + input, self._packed_params, reduce_range + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/linear.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..0faaf62cedb5047c3a595f54433d34caa20c4a2e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/linear.py @@ -0,0 +1,165 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.quantized as nnq +from torch.ao.nn.quantized.modules.utils import _quantize_weight + + +__all__ = [ + "Linear", +] + + +class Linear(nnq.Linear): + r""" + A dynamic quantized linear module with floating point tensor as inputs and outputs. + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation. + + Similar to :class:`torch.nn.Linear`, attributes will be randomly + initialized at module creation time and will be overwritten later + + Attributes: + weight (Tensor): the non-learnable quantized weights of the module which are of + shape :math:`(\text{out\_features}, \text{in\_features})`. + bias (Tensor): the non-learnable floating point bias of the module of shape + :math:`(\text{out\_features})`. If :attr:`bias` is ``True``, + the values are initialized to zero. + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.quantized.dynamic.Linear(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + # version used in this class is different from the parent class nnq.Linear + _version = 4 + + def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias_, dtype=dtype) + # We don't muck around with buffers or attributes or anything here + # to keep the module simple. *everything* is simply a Python attribute. + # Serialization logic is explicitly handled in the below serialization and + # deserialization modules + self.version = 4 + + def forward(self, x): + # Note that we can handle self.bias == None case. + if self._packed_params.dtype == torch.qint8: + if self.version is None or self.version < 4: + Y = torch.ops.quantized.linear_dynamic( + x, self._packed_params._packed_params + ) + else: + Y = torch.ops.quantized.linear_dynamic( + x, self._packed_params._packed_params, reduce_range=True + ) + elif self._packed_params.dtype == torch.float16: + Y = torch.ops.quantized.linear_dynamic_fp16( + x, self._packed_params._packed_params + ) + else: + raise RuntimeError("Unsupported dtype on dynamic quantized linear!") + return Y.to(x.dtype) + + def _get_name(self): + return "DynamicQuantizedLinear" + + def extra_repr(self): + extra_repr_str = f"in_features={self.in_features}, out_features={self.out_features}, dtype={self._packed_params.dtype}" + if self._packed_params.dtype == torch.qint8: + extra_repr_str += f", qscheme={self.weight().qscheme()}" + return extra_repr_str + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + self.version = version + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a dynamic quantized module from a float module or qparams_dict + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + float_modules = [ + torch.nn.Linear, + torch.nn.modules.linear.NonDynamicallyQuantizableLinear, + torch.ao.nn.intrinsic.modules.fused.LinearReLU, + torch.ao.nn.qat.dynamic.Linear, + ] + + assert type(mod) in float_modules, ( + "nn.quantized.dynamic.Linear.from_float only works for one of" + + str([float_mod.__name__ for float_mod in float_modules]) + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + if type(mod) == nni.LinearReLU: + mod = mod[0] + if mod.qconfig is not None and mod.qconfig.weight is not None: + weight_observer = mod.qconfig.weight() + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. + from torch.ao.quantization.qconfig import default_dynamic_qconfig + + weight_observer = default_dynamic_qconfig.weight() + dtype = weight_observer.dtype + assert dtype in [torch.qint8, torch.float16], ( + "The only supported dtypes for " + f"dynamic quantized linear are qint8 and float16 got: {dtype}" + ) + weight_observer(mod.weight) + if dtype == torch.qint8: + qweight = _quantize_weight(mod.weight.float(), weight_observer) + elif dtype == torch.float16: + qweight = mod.weight.float() + else: + raise RuntimeError( + "Unsupported dtype specified for dynamic quantized Linear!" + ) + qlinear = cls(mod.in_features, mod.out_features, dtype=dtype) + qlinear.set_weight_bias(qweight, mod.bias) + return qlinear + + @classmethod + def from_reference(cls, ref_qlinear): # type: ignore[override] + """Create a (fbgemm/qnnpack) dynamic quantized module from a reference quantized + module + Args: + ref_qlinear (Module): a reference quantized module, either produced by + torch.ao.quantization functions or provided by the user + """ + qlinear = cls( + ref_qlinear.in_features, + ref_qlinear.out_features, + dtype=ref_qlinear.weight_dtype, + ) + qweight = ref_qlinear.get_quantized_weight() + bias = ref_qlinear.bias + qlinear.set_weight_bias(qweight, bias) + return qlinear diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/rnn.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..10db59aafbf7ee638ead46e55ebaeff82c2e049b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/dynamic/modules/rnn.py @@ -0,0 +1,1363 @@ +# mypy: allow-untyped-defs +import numbers +import warnings +from typing_extensions import deprecated + +import torch +import torch.nn as nn +from torch import Tensor # noqa: F401 +from torch._jit_internal import Dict, List, Optional, Tuple, Union # noqa: F401 +from torch.ao.nn.quantized.modules.utils import _quantize_weight +from torch.nn.utils.rnn import PackedSequence + + +__all__ = [ + "pack_weight_bias", + "PackedParameter", + "RNNBase", + "LSTM", + "GRU", + "RNNCellBase", + "RNNCell", + "LSTMCell", + "GRUCell", + "apply_permutation", +] + + +def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return tensor.index_select(dim, permutation) + + +@deprecated( + "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead", + category=FutureWarning, +) +def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return _apply_permutation(tensor, permutation, dim) + + +def pack_weight_bias(qweight, bias, dtype): + if dtype == torch.qint8: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # w_ih, w_hh + packed_weight = torch.ops.quantized.linear_prepack(qweight, bias) + + return packed_weight + else: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # packed_ih, packed_hh, b_ih, b_hh + packed_weight = torch.ops.quantized.linear_prepack_fp16(qweight, bias) + + return packed_weight + + +class PackedParameter(torch.nn.Module): + def __init__(self, param): + super().__init__() + self.param = param + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "param"] = self.param + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.param = state_dict[prefix + "param"] + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class RNNBase(torch.nn.Module): + _FLOAT_MODULE = nn.RNNBase + + _version = 2 + + def __init__( + self, + mode, + input_size, + hidden_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0.0, + bidirectional=False, + dtype=torch.qint8, + ): + super().__init__() + + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.bidirectional = bidirectional + self.dtype = dtype + self.version = 2 + self.training = False + num_directions = 2 if bidirectional else 1 + + # "type: ignore" is required since ints and Numbers are not fully comparable + # https://github.com/python/mypy/issues/8566 + if ( + not isinstance(dropout, numbers.Number) + or not 0 <= dropout <= 1 # type: ignore[operator] + or isinstance(dropout, bool) + ): + raise ValueError( + "dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed" + ) + if dropout > 0 and num_layers == 1: # type: ignore[operator] + warnings.warn( + "dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + f"num_layers greater than 1, but got dropout={dropout} and " + f"num_layers={num_layers}" + ) + + if mode == "LSTM": + gate_size = 4 * hidden_size + elif mode == "GRU": + gate_size = 3 * hidden_size + else: + raise ValueError("Unrecognized RNN mode: " + mode) + + _all_weight_values = [] + for layer in range(num_layers): + for _ in range(num_directions): + layer_input_size = ( + input_size if layer == 0 else hidden_size * num_directions + ) + + w_ih = torch.randn(gate_size, layer_input_size).to(torch.float) + w_hh = torch.randn(gate_size, hidden_size).to(torch.float) + b_ih = torch.randn(gate_size).to(torch.float) + b_hh = torch.randn(gate_size).to(torch.float) + if dtype == torch.qint8: + w_ih = torch.quantize_per_tensor( + w_ih, scale=0.1, zero_point=0, dtype=torch.qint8 + ) + w_hh = torch.quantize_per_tensor( + w_hh, scale=0.1, zero_point=0, dtype=torch.qint8 + ) + packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih) + packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh) + if self.version is None or self.version < 2: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, b_ih, b_hh + ) + ) + else: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, b_ih, b_hh, True + ) + ) + else: + packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih) + packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh) + cell_params = torch.ops.quantized.make_quantized_cell_params_fp16( + packed_ih, packed_hh + ) + + _all_weight_values.append(PackedParameter(cell_params)) + self._all_weight_values = torch.nn.ModuleList(_all_weight_values) + + def _get_name(self): + return "DynamicQuantizedRNN" + + def extra_repr(self): + s = "{input_size}, {hidden_size}" + if self.num_layers != 1: + s += ", num_layers={num_layers}" + if self.bias is not True: + s += ", bias={bias}" + if self.batch_first is not False: + s += ", batch_first={batch_first}" + if self.dropout != 0: + s += ", dropout={dropout}" + if self.bidirectional is not False: + s += ", bidirectional={bidirectional}" + return s.format(**self.__dict__) + + def __repr__(self): + # We don't want to show `ModuleList` children, hence custom + # `__repr__`. This is the same as nn.Module.__repr__, except the check + # for the `PackedParameter` and `nn.ModuleList`. + # You should still override `extra_repr` to add more info. + extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [] + for key, module in self._modules.items(): + if isinstance(module, (PackedParameter, nn.ModuleList)): + continue + mod_str = repr(module) + mod_str = nn.modules.module._addindent(mod_str, 2) + child_lines.append("(" + key + "): " + mod_str) + lines = extra_lines + child_lines + + main_str = self._get_name() + "(" + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + + main_str += ")" + return main_str + + def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: + expected_input_dim = 2 if batch_sizes is not None else 3 + if input.dim() != expected_input_dim: + raise RuntimeError( + f"input must have {expected_input_dim} dimensions, got {input.dim()}" + ) + if self.input_size != input.size(-1): + raise RuntimeError( + f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}" + ) + + def get_expected_hidden_size( + self, input: Tensor, batch_sizes: Optional[Tensor] + ) -> tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) + return expected_hidden_size + + def check_hidden_size( + self, + hx: Tensor, + expected_hidden_size: tuple[int, int, int], + msg: str = "Expected hidden size {}, got {}", + ) -> None: + if hx.size() != expected_hidden_size: + raise RuntimeError(msg.format(expected_hidden_size, list(hx.size()))) + + def check_forward_args( + self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] + ) -> None: + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + self.check_hidden_size( + hidden, expected_hidden_size, msg="Expected hidden size {}, got {}" + ) + + def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor: + if permutation is None: + return hx + return _apply_permutation(hx, permutation) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + self.version = version + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def set_weight_bias(self, weight_bias_dict): + def weight_bias_name(ihhh, layer, suffix): + weight_name = f"weight_{ihhh}_l{layer}{suffix}" + bias_name = f"bias_{ihhh}_l{layer}{suffix}" + return weight_name, bias_name + + num_directions = 2 if self.bidirectional else 1 + # TODO: dedup with __init__ of RNNBase + _all_weight_values = [] + for layer in range(self.num_layers): + for direction in range(num_directions): + suffix = "_reverse" if direction == 1 else "" + w_ih_name, b_ih_name = weight_bias_name("ih", layer, suffix) + w_hh_name, b_hh_name = weight_bias_name("hh", layer, suffix) + w_ih = weight_bias_dict[w_ih_name] + b_ih = weight_bias_dict[b_ih_name] + w_hh = weight_bias_dict[w_hh_name] + b_hh = weight_bias_dict[b_hh_name] + if w_ih.dtype == torch.qint8: + packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih) + packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh) + if self.version is None or self.version < 2: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, b_ih, b_hh + ) + ) + else: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, b_ih, b_hh, True + ) + ) + else: + packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih) + packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh) + cell_params = torch.ops.quantized.make_quantized_cell_params_fp16( + packed_ih, packed_hh + ) + + _all_weight_values.append(PackedParameter(cell_params)) + self._all_weight_values = torch.nn.ModuleList(_all_weight_values) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) in { + torch.nn.LSTM, + torch.nn.GRU, + }, "nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU" + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + + if mod.qconfig is not None and mod.qconfig.weight is not None: + weight_observer_method = mod.qconfig.weight + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. + from torch.ao.quantization.qconfig import default_dynamic_qconfig + + weight_observer_method = default_dynamic_qconfig.weight + + dtype = weight_observer_method().dtype + supported_scalar_types = [torch.qint8, torch.float16] + if dtype not in supported_scalar_types: + raise RuntimeError( + f"Unsupported dtype for dynamic RNN quantization: {dtype}" + ) + # RNNBase can be either LSTM or GRU + qRNNBase: Union[LSTM, GRU] + if mod.mode == "LSTM": + qRNNBase = LSTM( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + dtype, + ) + elif mod.mode == "GRU": + qRNNBase = GRU( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + dtype, + ) + else: + raise NotImplementedError( + "Only LSTM/GRU is supported for QuantizedRNN for now" + ) + + num_directions = 2 if mod.bidirectional else 1 + + assert mod.bias + + _all_weight_values = [] + for layer in range(qRNNBase.num_layers): + for direction in range(num_directions): + suffix = "_reverse" if direction == 1 else "" + + def retrieve_weight_bias(ihhh): + weight_name = f"weight_{ihhh}_l{layer}{suffix}" + bias_name = f"bias_{ihhh}_l{layer}{suffix}" + weight = getattr(mod, weight_name) + bias = getattr(mod, bias_name) + return weight, bias + + weight_ih, bias_ih = retrieve_weight_bias("ih") + weight_hh, bias_hh = retrieve_weight_bias("hh") + + if dtype == torch.qint8: + + def quantize_and_pack(w, b): + weight_observer = weight_observer_method() + weight_observer(w) + qweight = _quantize_weight(w.float(), weight_observer) + packed_weight = torch.ops.quantized.linear_prepack(qweight, b) + return packed_weight + + packed_ih = quantize_and_pack(weight_ih, bias_ih) + packed_hh = quantize_and_pack(weight_hh, bias_hh) + if qRNNBase.version is None or qRNNBase.version < 2: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, bias_ih, bias_hh + ) + ) + else: + cell_params = ( + torch.ops.quantized.make_quantized_cell_params_dynamic( + packed_ih, packed_hh, bias_ih, bias_hh, True + ) + ) + + elif dtype == torch.float16: + packed_ih = torch.ops.quantized.linear_prepack_fp16( + weight_ih.float(), bias_ih + ) + packed_hh = torch.ops.quantized.linear_prepack_fp16( + weight_hh.float(), bias_hh + ) + + cell_params = torch.ops.quantized.make_quantized_cell_params_fp16( + packed_ih, packed_hh + ) + else: + raise RuntimeError( + "Unsupported dtype specified for dynamic quantized LSTM!" + ) + + _all_weight_values.append(PackedParameter(cell_params)) + qRNNBase._all_weight_values = torch.nn.ModuleList(_all_weight_values) + + return qRNNBase + + def _weight_bias(self): + # Returns a dict of weights and biases + weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}} + count = 0 + num_directions = 2 if self.bidirectional else 1 + for layer in range(self.num_layers): + for direction in range(num_directions): + suffix = "_reverse" if direction == 1 else "" + key_name1 = f"weight_ih_l{layer}{suffix}" + key_name2 = f"weight_hh_l{layer}{suffix}" + # packed weights are part of torchbind class, CellParamsSerializationType + # Within the packed weight class, the weight and bias are accessible as Tensors + packed_weight_bias = self._all_weight_values[ # type: ignore[index] + count + ].param.__getstate__()[0][4] + weight_bias_dict["weight"][key_name1] = packed_weight_bias[ + 0 + ].__getstate__()[0][0] + weight_bias_dict["weight"][key_name2] = packed_weight_bias[ + 1 + ].__getstate__()[0][0] + key_name1 = f"bias_ih_l{layer}{suffix}" + key_name2 = f"bias_hh_l{layer}{suffix}" + weight_bias_dict["bias"][key_name1] = packed_weight_bias[ + 0 + ].__getstate__()[0][1] + weight_bias_dict["bias"][key_name2] = packed_weight_bias[ + 1 + ].__getstate__()[0][1] + count = count + 1 + return weight_bias_dict + + def get_weight(self): + return self._weight_bias()["weight"] + + def get_bias(self): + return self._weight_bias()["bias"] + + +class LSTM(RNNBase): + r""" + A dynamic quantized LSTM module with floating point tensor as inputs and outputs. + We adopt the same interface as `torch.nn.LSTM`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM for documentation. + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.LSTM(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> c0 = torch.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + """ + + _FLOAT_MODULE = nn.LSTM + + __overloads__ = {"forward": ["forward_packed", "forward_tensor"]} + + def __init__(self, *args, **kwargs): + super().__init__("LSTM", *args, **kwargs) + + def _get_name(self): + return "DynamicQuantizedLSTM" + + def forward_impl( + self, + input: Tensor, + hx: Optional[tuple[Tensor, Tensor]], + batch_sizes: Optional[Tensor], + max_batch_size: int, + sorted_indices: Optional[Tensor], + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: + if hx is None: + num_directions = 2 if self.bidirectional else 1 + zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (zeros, zeros) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + + _all_params = [m.param for m in self._all_weight_values] + if batch_sizes is None: + result = torch.quantized_lstm( + input, + hx, + _all_params, + self.bias, + self.num_layers, + float(self.dropout), + self.training, + self.bidirectional, + self.batch_first, + dtype=self.dtype, + use_dynamic=True, + ) + else: + result = torch.quantized_lstm( + input, + batch_sizes, + hx, + _all_params, + self.bias, + self.num_layers, + float(self.dropout), + self.training, + self.bidirectional, + dtype=self.dtype, + use_dynamic=True, + ) + output = result[0] + hidden = result[1:] + + return output, hidden + + @torch.jit.export + def forward_tensor( + self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: + batch_sizes = None + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + output, hidden = self.forward_impl( + input, hx, batch_sizes, max_batch_size, sorted_indices + ) + + return output, self.permute_hidden(hidden, unsorted_indices) + + @torch.jit.export + def forward_packed( + self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: + input_, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + + output_, hidden = self.forward_impl( + input_, hx, batch_sizes, max_batch_size, sorted_indices + ) + + output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices) + return output, self.permute_hidden(hidden, unsorted_indices) + + # "type: ignore" is required due to issue #43072 + def permute_hidden( # type: ignore[override] + self, + hx: tuple[Tensor, Tensor], + permutation: Optional[Tensor], + ) -> tuple[Tensor, Tensor]: + if permutation is None: + return hx + return _apply_permutation(hx[0], permutation), _apply_permutation( + hx[1], permutation + ) + + # "type: ignore" is required due to issue #43072 + def check_forward_args( # type: ignore[override] + self, + input: Tensor, + hidden: tuple[Tensor, Tensor], + batch_sizes: Optional[Tensor], + ) -> None: + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + + self.check_hidden_size( + hidden[0], expected_hidden_size, "Expected hidden[0] size {}, got {}" + ) + self.check_hidden_size( + hidden[1], expected_hidden_size, "Expected hidden[1] size {}, got {}" + ) + + @torch.jit.ignore + def forward(self, input, hx=None): + if isinstance(input, PackedSequence): + return self.forward_packed(input, hx) + else: + return self.forward_tensor(input, hx) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_mod): + assert hasattr(ref_mod, "weight_ih_l0_dtype"), "We are assuming weight_ih_l0 " + "exists in LSTM, may need to relax the assumption to support the use case" + qmod = cls( + ref_mod.input_size, + ref_mod.hidden_size, + ref_mod.num_layers, + ref_mod.bias, + ref_mod.batch_first, + ref_mod.dropout, + ref_mod.bidirectional, + # assuming there is layer 0, which should be OK + ref_mod.weight_ih_l0_dtype, + ) + qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict()) + return qmod + + +class GRU(RNNBase): + r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. + + + For each element in the input sequence, each layer computes the following + function: + + .. math:: + \begin{array}{ll} + r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ + z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\ + h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)} + \end{array} + + where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input + at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer + at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`, + :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively. + :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer + (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by + dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random + variable which is :math:`0` with probability :attr:`dropout`. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two GRUs together to form a `stacked GRU`, + with the second GRU taking in outputs of the first GRU and + computing the final results. Default: 1 + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + GRU layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False`` + + Inputs: input, h_0 + - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features + of the input sequence. The input can also be a packed variable length + sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence` + for details. + - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor + containing the initial hidden state for each element in the batch. + Defaults to zero if not provided. If the RNN is bidirectional, + num_directions should be 2, else it should be 1. + + Outputs: output, h_n + - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor + containing the output features h_t from the last layer of the GRU, + for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been + given as the input, the output will also be a packed sequence. + For the unpacked case, the directions can be separated + using ``output.view(seq_len, batch, num_directions, hidden_size)``, + with forward and backward being direction `0` and `1` respectively. + + Similarly, the directions can be separated in the packed case. + - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor + containing the hidden state for `t = seq_len` + + Like *output*, the layers can be separated using + ``h_n.view(num_layers, num_directions, batch, hidden_size)``. + + Shape: + - Input1: :math:`(L, N, H_{in})` tensor containing input features where + :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length. + - Input2: :math:`(S, N, H_{out})` tensor + containing the initial hidden state for each element in the batch. + :math:`H_{out}=\text{hidden\_size}` + Defaults to zero if not provided. where :math:`S=\text{num\_layers} * \text{num\_directions}` + If the RNN is bidirectional, num_directions should be 2, else it should be 1. + - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}` + - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state + for each element in the batch + + Attributes: + weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer + (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`. + Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)` + weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer + (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)` + bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer + (b_ir|b_iz|b_in), of shape `(3*hidden_size)` + bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer + (b_hr|b_hz|b_hn), of shape `(3*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks. + In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the + previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix + `W` and addition of bias: + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn}) + \end{aligned} + + This is in contrast to PyTorch implementation, which is done after :math:`W_{hn} h_{(t-1)}` + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) + \end{aligned} + + This implementation differs on purpose for efficiency. + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.GRU(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + """ + + _FLOAT_MODULE = nn.GRU + + __overloads__ = {"forward": ["forward_packed", "forward_tensor"]} + + def __init__(self, *args, **kwargs): + super().__init__("GRU", *args, **kwargs) + + def _get_name(self): + return "DynamicQuantizedGRU" + + def check_forward_args( + self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] + ) -> None: + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + + self.check_hidden_size( + hidden, expected_hidden_size, "Expected hidden size {}, got {}" + ) + + def forward_impl( + self, + input: Tensor, + hx: Optional[Tensor], + batch_sizes: Optional[Tensor], + max_batch_size: int, + sorted_indices: Optional[Tensor], + ) -> tuple[Tensor, Tensor]: + if hx is None: + num_directions = 2 if self.bidirectional else 1 + zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = zeros + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + + _all_params = [m.param for m in self._all_weight_values] + if batch_sizes is None: + result = torch.quantized_gru( + input, + hx, + _all_params, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = torch.quantized_gru( + input, + batch_sizes, + hx, + _all_params, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1] + + return output, hidden + + @torch.jit.export + def forward_tensor( + self, input: Tensor, hx: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: + batch_sizes = None + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + output, hidden = self.forward_impl( + input, hx, batch_sizes, max_batch_size, sorted_indices + ) + + return output, self.permute_hidden(hidden, unsorted_indices) + + @torch.jit.export + def forward_packed( + self, input: PackedSequence, hx: Optional[Tensor] = None + ) -> tuple[PackedSequence, Tensor]: + input_, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + output_, hidden = self.forward_impl( + input_, hx, batch_sizes, max_batch_size, sorted_indices + ) + + output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices) + return output, self.permute_hidden(hidden, unsorted_indices) + + def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor: + if permutation is None: + return hx + return _apply_permutation(hx, permutation) + + @torch.jit.ignore + def forward(self, input, hx=None): + if isinstance(input, PackedSequence): + return self.forward_packed(input, hx) + else: + return self.forward_tensor(input, hx) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + @classmethod + def from_reference(cls, ref_mod): + assert hasattr(ref_mod, "weight_ih_l0_dtype"), "We are assuming weight_ih_l0 " + "exists in LSTM, may need to relax the assumption to support the use case" + qmod = cls( + ref_mod.input_size, + ref_mod.hidden_size, + ref_mod.num_layers, + ref_mod.bias, + ref_mod.batch_first, + ref_mod.dropout, + ref_mod.bidirectional, + # assuming there is layer 0, which should be OK + ref_mod.weight_ih_l0_dtype, + ) + qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict()) + return qmod + + +class RNNCellBase(torch.nn.Module): + # _FLOAT_MODULE = nn.CellRNNBase + __constants__ = ["input_size", "hidden_size", "bias"] + + def __init__( + self, input_size, hidden_size, bias=True, num_chunks=4, dtype=torch.qint8 + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.weight_dtype = dtype + if bias: + self.bias_ih = torch.randn(num_chunks * hidden_size).to(dtype=torch.float) + self.bias_hh = torch.randn(num_chunks * hidden_size).to(dtype=torch.float) + else: + self.register_parameter("bias_ih", None) + self.register_parameter("bias_hh", None) + + weight_ih = torch.randn(num_chunks * hidden_size, input_size).to(torch.float) + weight_hh = torch.randn(num_chunks * hidden_size, hidden_size).to(torch.float) + if dtype == torch.qint8: + weight_ih = torch.quantize_per_tensor( + weight_ih, scale=1, zero_point=0, dtype=torch.qint8 + ) + weight_hh = torch.quantize_per_tensor( + weight_hh, scale=1, zero_point=0, dtype=torch.qint8 + ) + + if dtype == torch.qint8: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # w_ih, w_hh + packed_weight_ih = torch.ops.quantized.linear_prepack( + weight_ih, self.bias_ih + ) + packed_weight_hh = torch.ops.quantized.linear_prepack( + weight_hh, self.bias_hh + ) + else: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # packed_ih, packed_hh, b_ih, b_hh + packed_weight_ih = torch.ops.quantized.linear_prepack_fp16( + weight_ih, self.bias_ih + ) + packed_weight_hh = torch.ops.quantized.linear_prepack_fp16( + weight_hh, self.bias_hh + ) + + self._packed_weight_ih = packed_weight_ih + self._packed_weight_hh = packed_weight_hh + + def _get_name(self): + return "DynamicQuantizedRNNBase" + + def extra_repr(self): + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": + s += ", nonlinearity={nonlinearity}" + return s.format(**self.__dict__) + + def check_forward_input(self, input): + if input.size(1) != self.input_size: + raise RuntimeError( + f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}" + ) + + def check_forward_hidden( + self, input: Tensor, hx: Tensor, hidden_label: str = "" + ) -> None: + if input.size(0) != hx.size(0): + raise RuntimeError( + f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}" + ) + + if hx.size(1) != self.hidden_size: + raise RuntimeError( + f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}" + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) in { + torch.nn.LSTMCell, + torch.nn.GRUCell, + torch.nn.RNNCell, + }, ( + "nn.quantized.dynamic.RNNCellBase.from_float \ + only works for nn.LSTMCell, nn.GRUCell and nn.RNNCell" + ) + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + + if mod.qconfig is not None and mod.qconfig.weight is not None: + weight_observer_method = mod.qconfig.weight + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. + from torch.ao.quantization.qconfig import default_dynamic_qconfig + + weight_observer_method = default_dynamic_qconfig.weight + + dtype = weight_observer_method().dtype + supported_scalar_types = [torch.qint8, torch.float16] + if dtype not in supported_scalar_types: + raise RuntimeError( + f"Unsupported dtype for dynamic RNN quantization: {dtype}" + ) + + qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell] + + if type(mod) == torch.nn.LSTMCell: + qRNNCellBase = LSTMCell( + mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype + ) + elif type(mod) == torch.nn.GRUCell: + qRNNCellBase = GRUCell( + mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype + ) + elif type(mod) == torch.nn.RNNCell: + qRNNCellBase = RNNCell( + mod.input_size, + mod.hidden_size, + bias=mod.bias, + nonlinearity=mod.nonlinearity, + dtype=dtype, + ) + else: + raise NotImplementedError( + "Only LSTMCell, GRUCell and RNNCell \ + are supported for QuantizedRNN for now" + ) + + assert mod.bias + + def _observe_and_quantize_weight(weight): + if dtype == torch.qint8: + weight_observer = weight_observer_method() + weight_observer(weight) + qweight = _quantize_weight(weight.float(), weight_observer) + return qweight + else: + return weight.float() + + qRNNCellBase._packed_weight_ih = pack_weight_bias( + _observe_and_quantize_weight(mod.weight_ih), mod.bias_ih, dtype + ) + qRNNCellBase._packed_weight_hh = pack_weight_bias( + _observe_and_quantize_weight(mod.weight_hh), mod.bias_hh, dtype + ) + return qRNNCellBase + + @classmethod + def from_reference(cls, ref_mod): + assert hasattr(ref_mod, "weight_ih_dtype"), "We are assuming weight_ih " + "exists in reference module, may need to relax the assumption to support the use case" + if hasattr(ref_mod, "nonlinearity"): + qmod = cls( + ref_mod.input_size, + ref_mod.hidden_size, + ref_mod.bias, + ref_mod.nonlinearity, + dtype=ref_mod.weight_ih_dtype, + ) + else: + qmod = cls( + ref_mod.input_size, + ref_mod.hidden_size, + ref_mod.bias, + dtype=ref_mod.weight_ih_dtype, + ) + weight_bias_dict = { + "weight": { + "weight_ih": ref_mod.get_quantized_weight_ih(), + "weight_hh": ref_mod.get_quantized_weight_hh(), + }, + "bias": { + "bias_ih": ref_mod.bias_ih, + "bias_hh": ref_mod.bias_hh, + }, + } + qmod.set_weight_bias(weight_bias_dict) + return qmod + + def _weight_bias(self): + # Returns a dict of weights and biases + weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}} + w1, b1 = self._packed_weight_ih.__getstate__()[0] + w2, b2 = self._packed_weight_hh.__getstate__()[0] + # TODO: these can be simplified to one level? e.g. using weight_ih as key + # directly + weight_bias_dict["weight"]["weight_ih"] = w1 + weight_bias_dict["weight"]["weight_hh"] = w2 + weight_bias_dict["bias"]["bias_ih"] = b1 + weight_bias_dict["bias"]["bias_hh"] = b2 + return weight_bias_dict + + def get_weight(self): + return self._weight_bias()["weight"] + + def get_bias(self): + return self._weight_bias()["bias"] + + def set_weight_bias(self, weight_bias_dict): + # TODO: these can be simplified to one level? e.g. using weight_ih as key + # directly + self._packed_weight_ih = pack_weight_bias( + weight_bias_dict["weight"]["weight_ih"], + weight_bias_dict["bias"]["bias_ih"], + self.weight_dtype, + ) + self._packed_weight_hh = pack_weight_bias( + weight_bias_dict["weight"]["weight_hh"], + weight_bias_dict["bias"]["bias_hh"], + self.weight_dtype, + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "_packed_weight_ih"] = self._packed_weight_ih + destination[prefix + "_packed_weight_hh"] = self._packed_weight_hh + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self._packed_weight_ih = state_dict.pop(prefix + "_packed_weight_ih") + self._packed_weight_hh = state_dict.pop(prefix + "_packed_weight_hh") + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class RNNCell(RNNCellBase): + r"""An Elman RNN cell with tanh or ReLU non-linearity. + A dynamic quantized RNNCell module with floating point tensor as inputs and outputs. + Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.RNNCell`, + please see https://pytorch.org/docs/stable/nn.html#torch.nn.RNNCell for documentation. + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.RNNCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... output.append(hx) + """ + + __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"] + + def __init__( + self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtype=torch.qint8 + ): + super().__init__(input_size, hidden_size, bias, num_chunks=1, dtype=dtype) + self.nonlinearity = nonlinearity + + def _get_name(self): + return "DynamicQuantizedRNNCell" + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + self.check_forward_input(input) + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + self.check_forward_hidden(input, hx, "") + if self.nonlinearity == "tanh": + ret = torch.ops.quantized.quantized_rnn_tanh_cell_dynamic( + input, + hx, + self._packed_weight_ih, + self._packed_weight_hh, + self.bias_ih, + self.bias_hh, + ) + elif self.nonlinearity == "relu": + ret = torch.ops.quantized.quantized_rnn_relu_cell_dynamic( + input, + hx, + self._packed_weight_ih, + self._packed_weight_hh, + self.bias_ih, + self.bias_hh, + ) + else: + ret = input # TODO: remove when jit supports exception flow + raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}") + return ret + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class LSTMCell(RNNCellBase): + r"""A long short-term memory (LSTM) cell. + + A dynamic quantized LSTMCell module with floating point tensor as inputs and outputs. + Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.LSTMCell`, + please see https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell for documentation. + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.LSTMCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> cx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx, cx = rnn(input[i], (hx, cx)) + ... output.append(hx) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, num_chunks=4, **kwargs) # type: ignore[misc] + + def _get_name(self): + return "DynamicQuantizedLSTMCell" + + def forward( + self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[Tensor, Tensor]: + self.check_forward_input(input) + if hx is None: + zeros = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + hx = (zeros, zeros) + self.check_forward_hidden(input, hx[0], "[0]") + self.check_forward_hidden(input, hx[1], "[1]") + return torch.ops.quantized.quantized_lstm_cell_dynamic( + input, + hx, + self._packed_weight_ih, + self._packed_weight_hh, + self.bias_ih, + self.bias_hh, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class GRUCell(RNNCellBase): + r"""A gated recurrent unit (GRU) cell + + A dynamic quantized GRUCell module with floating point tensor as inputs and outputs. + Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.GRUCell`, + please see https://pytorch.org/docs/stable/nn.html#torch.nn.GRUCell for documentation. + + Examples:: + + >>> # xdoctest: +SKIP + >>> rnn = nn.GRUCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... output.append(hx) + """ + + def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8): + super().__init__(input_size, hidden_size, bias, num_chunks=3, dtype=dtype) + + def _get_name(self): + return "DynamicQuantizedGRUCell" + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + self.check_forward_input(input) + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + self.check_forward_hidden(input, hx, "") + return torch.ops.quantized.quantized_gru_cell_dynamic( + input, + hx, + self._packed_weight_ih, + self._packed_weight_hh, + self.bias_ih, + self.bias_hh, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/functional.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..51a2f4905c257c43f93d2b0661d7903a31ec1759 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/functional.py @@ -0,0 +1,779 @@ +# mypy: allow-untyped-defs +r"""Functional interface (quantized).""" + +import warnings +from typing import Optional + +import torch +from torch import Tensor +from torch.jit.annotations import BroadcastingList2 +from torch.nn.modules.utils import _pair, _triple + +from .modules.utils import _pair_from_first + + +# Although some of the functions and docstrings are mirrored from the torch.nn, +# we want to have them here for future changes. + +__all__ = [ + "avg_pool2d", + "avg_pool3d", + "adaptive_avg_pool2d", + "adaptive_avg_pool3d", + "conv1d", + "conv2d", + "conv3d", + "interpolate", + "linear", + "max_pool1d", + "max_pool2d", + "celu", + "leaky_relu", + "hardtanh", + "hardswish", + "threshold", + "elu", + "hardsigmoid", + "clamp", + "upsample", + "upsample_bilinear", + "upsample_nearest", +] + + +def avg_pool2d( + input, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + r""" + Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size + :math:`sH \times sW` steps. The number of output features is equal to the number of + input planes. + + .. note:: The input quantization parameters propagate to the output. + + See :class:`~torch.ao.nn.quantized.AvgPool2d` for details and output shape. + + Args: + input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sH, sW)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padH, padW)`. Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` in the formula + to compute the output shape. Default: ``False`` + count_include_pad: when True, will include the zero-padding in the + averaging calculation. Default: ``True`` + divisor_override: if specified, it will be used as divisor, otherwise + size of the pooling region will be used. Default: None + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.avg_pool2d' must be quantized!") + return torch.nn.functional.avg_pool2d( + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + +def avg_pool3d( + input, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + r""" + Applies 3D average-pooling operation in :math:`kD \ times kH \times kW` regions by step size + :math:`sD \times sH \times sW` steps. The number of output features is equal to the number of + input planes. + + .. note:: The input quantization parameters propagate to the output. + + Args: + input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kD, kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sD, sH, sW)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padD, padH, padW)`. Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` in the formula + to compute the output shape. Default: ``False`` + count_include_pad: when True, will include the zero-padding in the + averaging calculation. Default: ``True`` + divisor_override: if specified, it will be used as divisor, otherwise + size of the pooling region will be used. Default: None + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.avg_pool3d' must be quantized!") + return torch.nn.functional.avg_pool3d( + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + +def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor: + r""" + Applies a 2D adaptive average pooling over a quantized input signal composed + of several quantized input planes. + + .. note:: The input quantization parameters propagate to the output. + + See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool2d` for details and output shape. + + Args: + output_size: the target output size (single integer or + double-integer tuple) + """ + if not input.is_quantized: + raise ValueError( + "Input to 'quantized.functional.adaptive_avg_pool2d' must be quantized!" + ) + return torch.nn.functional.adaptive_avg_pool2d(input, output_size) + + +def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor: + r""" + Applies a 3D adaptive average pooling over a quantized input signal composed + of several quantized input planes. + + .. note:: The input quantization parameters propagate to the output. + + See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool3d` for details and output shape. + + Args: + output_size: the target output size (single integer or + double-integer tuple) + """ + if not input.is_quantized: + raise ValueError( + "Input to 'quantized.functional.adaptive_avg_pool3d' must be quantized!" + ) + return torch.nn.functional.adaptive_avg_pool3d(input, output_size) + + +def conv1d( + input, + weight, + bias, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode="zeros", + scale=1.0, + zero_point=0, + dtype=torch.quint8, +): + r""" + Applies a 1D convolution over a quantized 1D input composed of several input + planes. + + See :class:`~torch.ao.nn.quantized.Conv1d` for details and output shape. + + Args: + input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` + weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , iW)` + bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`. + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sW,)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a + single number or a tuple `(padW,)`. Default: 0 + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dW,)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros" + scale: quantization scale for the output. Default: 1.0 + zero_point: quantization zero_point for the output. Default: 0 + dtype: quantization data type to use. Default: ``torch.quint8`` + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> from torch.ao.nn.quantized import functional as qF + >>> filters = torch.randn(33, 16, 3, dtype=torch.float) + >>> inputs = torch.randn(20, 16, 50, dtype=torch.float) + >>> bias = torch.randn(33, dtype=torch.float) + >>> + >>> scale, zero_point = 1.0, 0 + >>> dtype_inputs = torch.quint8 + >>> dtype_filters = torch.qint8 + >>> + >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters) + >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs) + >>> qF.conv1d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point) + """ # noqa: E501 + if padding_mode != "zeros": + raise NotImplementedError("Only zero-padding is supported!") + if input.dtype != torch.quint8: + raise NotImplementedError( + "Only torch.quint8 is supported for activation tensor!" + ) + if weight.dtype != torch.qint8: + raise NotImplementedError("Only torch.qint8 is supported for weight tensor!") + if input.ndim != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + stride = _pair_from_first(stride) + padding = _pair_from_first(padding) + dilation = _pair_from_first(dilation) + + packed_params = torch.ops.quantized.conv1d_prepack( + weight, bias, stride, padding, dilation, groups + ) + return torch.ops.quantized.conv1d(input, packed_params, scale, zero_point) + + +def conv2d( + input, + weight, + bias, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode="zeros", + scale=1.0, + zero_point=0, + dtype=torch.quint8, +): + r""" + Applies a 2D convolution over a quantized 2D input composed of several input + planes. + + See :class:`~torch.ao.nn.quantized.Conv2d` for details and output shape. + + Args: + input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)` + bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`. + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sH, sW)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a + single number or a tuple `(padH, padW)`. Default: 0 + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dH, dW)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros" + scale: quantization scale for the output. Default: 1.0 + zero_point: quantization zero_point for the output. Default: 0 + dtype: quantization data type to use. Default: ``torch.quint8`` + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> from torch.ao.nn.quantized import functional as qF + >>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float) + >>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float) + >>> bias = torch.randn(8, dtype=torch.float) + >>> + >>> scale, zero_point = 1.0, 0 + >>> dtype_inputs = torch.quint8 + >>> dtype_filters = torch.qint8 + >>> + >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters) + >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs) + >>> qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point) + """ # noqa: E501 + if padding_mode != "zeros": + raise NotImplementedError("Only zero-padding is supported!") + if input.dtype != torch.quint8: + raise NotImplementedError( + "Only torch.quint8 is supported for activation tensor!" + ) + if weight.dtype != torch.qint8: + raise NotImplementedError("Only torch.qint8 is supported for weight tensor!") + if input.ndim != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + + packed_params = torch.ops.quantized.conv2d_prepack( + weight, bias, stride, padding, dilation, groups + ) + return torch.ops.quantized.conv2d(input, packed_params, scale, zero_point) + + +def conv3d( + input, + weight, + bias, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode="zeros", + scale=1.0, + zero_point=0, + dtype=torch.quint8, +): + r""" + Applies a 3D convolution over a quantized 3D input composed of several input + planes. + + See :class:`~torch.ao.nn.quantized.Conv3d` for details and output shape. + + Args: + input: quantized input tensor of shape + :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)` + weight: quantized filters of shape + :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kD , kH , kW)` + bias: **non-quantized** bias tensor of shape + :math:`(\text{out\_channels})`. The tensor type must be `torch.float`. + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sD, sH, sW)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a + single number or a tuple `(padD, padH, padW)`. Default: 0 + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dD, dH, dW)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be + divisible by the number of groups. Default: 1 + padding_mode: the padding mode to use. Only "zeros" is supported for + quantized convolution at the moment. Default: "zeros" + scale: quantization scale for the output. Default: 1.0 + zero_point: quantization zero_point for the output. Default: 0 + dtype: quantization data type to use. Default: ``torch.quint8`` + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> from torch.ao.nn.quantized import functional as qF + >>> filters = torch.randn(8, 4, 3, 3, 3, dtype=torch.float) + >>> inputs = torch.randn(1, 4, 5, 5, 5, dtype=torch.float) + >>> bias = torch.randn(8, dtype=torch.float) + >>> + >>> scale, zero_point = 1.0, 0 + >>> dtype_inputs = torch.quint8 + >>> dtype_filters = torch.qint8 + >>> + >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters) + >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs) + >>> qF.conv3d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point) + """ # noqa: E501 + if padding_mode != "zeros": + raise NotImplementedError("Only zero-padding is supported!") + if input.dtype != torch.quint8: + raise NotImplementedError( + "Only torch.quint8 is supported for activation tensor!" + ) + if weight.dtype != torch.qint8: + raise NotImplementedError("Only torch.qint8 is supported for weight tensor!") + if input.ndim != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + + packed_params = torch.ops.quantized.conv3d_prepack( + weight, bias, stride, padding, dilation, groups + ) + return torch.ops.quantized.conv3d(input, packed_params, scale, zero_point) + + +def interpolate( + input, size=None, scale_factor=None, mode="nearest", align_corners=None +): + r"""Down/up samples the input to either the given :attr:`size` or the given + :attr:`scale_factor` + + See :func:`torch.nn.functional.interpolate` for implementation details. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D/3D input is supported for quantized inputs + + .. note:: Only the following modes are supported for the quantized inputs: + + - `bilinear` + - `nearest` + + Args: + input (Tensor): the input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'bilinear'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'bilinear'``. + Default: ``False`` + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.interpolate' must be quantized!") + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + +def linear( + input: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + scale: Optional[float] = None, + zero_point: Optional[int] = None, +) -> Tensor: + r""" + Applies a linear transformation to the incoming quantized data: + :math:`y = xA^T + b`. + See :class:`~torch.ao.nn.quantized.Linear` + + .. note:: + + Current implementation packs weights on every call, which has penalty on performance. + If you want to avoid the overhead, use :class:`~torch.ao.nn.quantized.Linear`. + + Args: + input (Tensor): Quantized input of type `torch.quint8` + weight (Tensor): Quantized weight of type `torch.qint8` + bias (Tensor): None or fp32 bias of type `torch.float` + scale (double): output scale. If None, derived from the input scale + zero_point (long): output zero point. If None, derived from the input zero_point + + Shape: + - Input: :math:`(N, *, in\_features)` where `*` means any number of + additional dimensions + - Weight: :math:`(out\_features, in\_features)` + - Bias: :math:`(out\_features)` + - Output: :math:`(N, *, out\_features)` + """ + if scale is None: + scale = input.q_scale() + if zero_point is None: + zero_point = input.q_zero_point() + _packed_params = torch.ops.quantized.linear_prepack(weight, bias) + return torch.ops.quantized.linear(input, _packed_params, scale, zero_point) + + +def max_pool1d( + input, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, + return_indices=False, +): + r"""Applies a 1D max pooling over a quantized input signal composed of + several quantized input planes. + + .. note:: The input quantization parameters are propagated to the output. + + See :class:`~torch.ao.nn.quantized.MaxPool1d` for details. + """ + if return_indices: + raise NotImplementedError("return_indices is not yet implemented!") + if stride is None: + stride = torch.jit.annotate(list[int], []) + return torch.nn.functional.max_pool1d( + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + + +def max_pool2d( + input, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, + return_indices=False, +): + r"""Applies a 2D max pooling over a quantized input signal composed of + several quantized input planes. + + .. note:: The input quantization parameters are propagated to the output. + + See :class:`~torch.ao.nn.quantized.MaxPool2d` for details. + """ + if return_indices: + raise NotImplementedError("return_indices is not yet implemented!") + if stride is None: + stride = torch.jit.annotate(list[int], []) + return torch.nn.functional.max_pool2d( + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + + +def celu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.0) -> Tensor: + r"""celu(input, scale, zero_point, alpha=1.) -> Tensor + + Applies the quantized CELU function element-wise. + + .. math:: + \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x / \alpha) - 1)) + + Args: + input: quantized input + alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0 + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.celu' must be quantized!") + return torch.ops.quantized.celu(input, scale, zero_point, alpha) + + +def leaky_relu( + input: Tensor, + negative_slope: float = 0.01, + inplace: bool = False, + scale: Optional[float] = None, + zero_point: Optional[int] = None, +): + r""" + Quantized version of the. + leaky_relu(input, negative_slope=0.01, inplace=False, scale, zero_point) -> Tensor + + Applies element-wise, + :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)` + + Args: + input: Quantized input + negative_slope: The slope of the negative input + inplace: Inplace modification of the input tensor + scale, zero_point: Scale and zero point of the output tensor. + + See :class:`~torch.nn.LeakyReLU` for more details. + """ + if scale is not None and zero_point is not None: + assert not inplace, "Cannot rescale with `inplace`" + output = torch._empty_affine_quantized( + input.shape, scale=scale, zero_point=int(zero_point), dtype=input.dtype + ) + torch._C._nn.leaky_relu(input, negative_slope, out=output) + return output + if inplace: + result = torch._C._nn.leaky_relu_(input, negative_slope) + else: + result = torch._C._nn.leaky_relu(input, negative_slope) + return result + + +def hardtanh( + input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False +) -> Tensor: + r"""This is the quantized version of :func:`~torch.nn.functional.hardtanh`.""" + if not input.is_quantized: + raise ValueError("Input to 'quantized.hardtanh' must be quantized!") + if inplace: + return torch._C._nn.hardtanh_(input, min_val, max_val) + return torch._C._nn.hardtanh(input, min_val, max_val) + + +def hardswish(input: Tensor, scale: float, zero_point: int) -> Tensor: + r"""This is the quantized version of :func:`~torch.nn.functional.hardswish`. + + Args: + input: quantized input + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.hardswish' must be quantized!") + return torch._ops.ops.quantized.hardswish(input, scale, zero_point) + + +def threshold(input: Tensor, threshold: float, value: float) -> Tensor: + r"""Applies the quantized version of the threshold function element-wise: + + .. math:: + x = \begin{cases} + x & \text{if~} x > \text{threshold} \\ + \text{value} & \text{otherwise} + \end{cases} + + See :class:`~torch.nn.Threshold` for more details. + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.threshold' must be quantized!") + if threshold is None: + raise ValueError("Input to 'threshold' must be specified!") + if value is None: + raise ValueError("Input to 'value' must be specified!") + return torch._ops.ops.quantized.threshold(input, threshold, value) + + +def elu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.0) -> Tensor: + r"""This is the quantized version of :func:`~torch.nn.functional.elu`. + + Args: + input: quantized input + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + alpha: the alpha constant + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.elu' must be quantized!") + return torch.ops.quantized.elu(input, scale, zero_point, alpha) + + +def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: + r"""This is the quantized version of :func:`~torch.nn.functional.hardsigmoid`.""" + if not input.is_quantized: + raise ValueError("Input to 'quantized.hardsigmoid' must be quantized!") + if inplace: + return torch._C._nn.hardsigmoid_(input) # type: ignore[attr-defined] + return torch._C._nn.hardsigmoid(input) + + +def clamp(input: Tensor, min_: float, max_: float) -> Tensor: + r"""float(input, min\_, max\_) -> Tensor + + Applies the clamp function element-wise. + See :class:`~torch.ao.nn.quantized.clamp` for more details. + + Args: + input: quantized input + min_: minimum value for clamping + max_: maximum value for clamping + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.clamp' must be quantized!") + return torch.clamp(input, min_, max_) + + +def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + r"""Upsamples the input to either the given :attr:`size` or the given + :attr:`scale_factor` + + .. warning:: + This function is deprecated in favor of + :func:`torch.ao.nn.quantized.functional.interpolate`. + This is equivalent with ``nn.quantized.functional.interpolate(...)``. + + See :func:`torch.nn.functional.interpolate` for implementation details. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D input is supported for quantized inputs + + .. note:: Only the following modes are supported for the quantized inputs: + + - `bilinear` + - `nearest` + + Args: + input (Tensor): quantized input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. Has to be an integer. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'bilinear'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'bilinear'``. + Default: ``False`` + + .. warning:: + With ``align_corners = True``, the linearly interpolating modes + (`bilinear`) don't proportionally align the + output and input pixels, and thus the output values can depend on the + input size. This was the default behavior for these modes up to version + 0.3.1. Since then, the default behavior is ``align_corners = False``. + See :class:`~torch.nn.Upsample` for concrete examples on how this + affects the outputs. + """ + warnings.warn( + "nn.quantized.functional.upsample is deprecated. Use nn.quantized.functional.interpolate instead." + ) + return interpolate(input, size, scale_factor, mode, align_corners) + + +def upsample_bilinear(input, size=None, scale_factor=None): + r"""Upsamples the input, using bilinear upsampling. + + .. warning:: + This function is deprecated in favor of + :func:`torch.ao.nn.quantized.functional.interpolate`. + This is equivalent with + ``nn.quantized.functional.interpolate(..., mode='bilinear', align_corners=True)``. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D inputs are supported + + Args: + input (Tensor): quantized input + size (int or Tuple[int, int]): output spatial size. + scale_factor (int or Tuple[int, int]): multiplier for spatial size + """ + # DeprecationWarning is ignored by default + warnings.warn( + "nn.quantized.functional.upsample_bilinear is deprecated. Use nn.quantized.functional.interpolate instead." + ) + return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) + + +def upsample_nearest(input, size=None, scale_factor=None): + r"""Upsamples the input, using nearest neighbours' pixel values. + + .. warning:: + This function is deprecated in favor of + :func:`torch.ao.nn.quantized.functional.interpolate`. + This is equivalent with ``nn.quantized.functional.interpolate(..., mode='nearest')``. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D inputs are supported + + Args: + input (Tensor): quantized input + size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial + size. + scale_factor (int): multiplier for spatial size. Has to be an integer. + """ + # DeprecationWarning is ignored by default + warnings.warn( + "nn.quantized.functional.upsample_nearest is deprecated. Use nn.quantized.functional.interpolate instead." + ) + return interpolate(input, size, scale_factor, mode="nearest") diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a3bad8c49350f56e5e58235570799a8d0968296d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__init__.py @@ -0,0 +1,162 @@ +# mypy: allow-untyped-defs +import torch + +# The quantized modules use `torch.nn` and `torch.ao.nn.quantizable` +# packages. However, the `quantizable` package uses "lazy imports" +# to avoid circular dependency. +# Hence we need to include it here to make sure it is resolved before +# they are used in the modules. +import torch.ao.nn.quantizable +from torch.nn.modules.pooling import MaxPool2d + +from .activation import ( + ELU, + Hardswish, + LeakyReLU, + MultiheadAttention, + PReLU, + ReLU6, + Sigmoid, + Softmax, +) +from .batchnorm import BatchNorm2d, BatchNorm3d +from .conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from .dropout import Dropout +from .embedding_ops import Embedding, EmbeddingBag +from .functional_modules import FloatFunctional, FXFloatFunctional, QFunctional +from .linear import Linear +from .normalization import ( + GroupNorm, + InstanceNorm1d, + InstanceNorm2d, + InstanceNorm3d, + LayerNorm, +) +from .rnn import LSTM + + +__all__ = [ + "BatchNorm2d", + "BatchNorm3d", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "DeQuantize", + "ELU", + "Embedding", + "EmbeddingBag", + "GroupNorm", + "Hardswish", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LayerNorm", + "LeakyReLU", + "Linear", + "LSTM", + "MultiheadAttention", + "Quantize", + "ReLU6", + "Sigmoid", + "Softmax", + "Dropout", + "PReLU", + # Wrapper modules + "FloatFunctional", + "FXFloatFunctional", + "QFunctional", +] + + +class Quantize(torch.nn.Module): + r"""Quantizes an incoming tensor + + Args: + `scale`: scale of the output Quantized Tensor + `zero_point`: zero_point of output Quantized Tensor + `dtype`: data type of output Quantized Tensor + `factory_kwargs`: Dictionary of kwargs used for configuring initialization + of internal buffers. Currently, `device` and `dtype` are supported. + Example: `factory_kwargs={'device': 'cuda', 'dtype': torch.float64}` + will initialize internal buffers as type `torch.float64` on the current CUDA device. + Note that `dtype` only applies to floating-point buffers. + + Examples:: + >>> t = torch.tensor([[1., -1.], [1., -1.]]) + >>> scale, zero_point, dtype = 1.0, 2, torch.qint8 + >>> qm = Quantize(scale, zero_point, dtype) + >>> # xdoctest: +SKIP + >>> qt = qm(t) + >>> print(qt) + tensor([[ 1., -1.], + [ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2) + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__(self, scale, zero_point, dtype, factory_kwargs=None): + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + super().__init__() + self.register_buffer("scale", torch.tensor([scale], **factory_kwargs)) + self.register_buffer( + "zero_point", + torch.tensor( + [zero_point], + dtype=torch.long, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ), + ) + self.dtype = dtype + + def forward(self, X): + return torch.quantize_per_tensor( + X, float(self.scale), int(self.zero_point), self.dtype + ) + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + assert hasattr(mod, "activation_post_process") + scale, zero_point = mod.activation_post_process.calculate_qparams() + return Quantize( + scale.float().item(), + zero_point.long().item(), + mod.activation_post_process.dtype, + ) + + def extra_repr(self): + return f"scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}" + + +class DeQuantize(torch.nn.Module): + r"""Dequantizes an incoming tensor + + Examples:: + >>> input = torch.tensor([[1., -1.], [1., -1.]]) + >>> scale, zero_point, dtype = 1.0, 2, torch.qint8 + >>> qm = Quantize(scale, zero_point, dtype) + >>> # xdoctest: +SKIP + >>> quantized_input = qm(input) + >>> dqm = DeQuantize() + >>> dequantized = dqm(quantized_input) + >>> print(dequantized) + tensor([[ 1., -1.], + [ 1., -1.]], dtype=torch.float32) + """ + + def forward(self, Xq): + return Xq.dequantize() + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + return DeQuantize() diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bc87e5af889544f2e4a4cc3de11865a949f3d1a Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/activation.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/activation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57ab24974a6c1bd4ef734cf314ef073f4ad99d20 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/activation.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..220675237bac4ba99d827d905c2c027750beb406 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..046ec37d04bcc860dc3eaef67206e3ed14d7b5b6 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20554b77eff297536fc0b2ee6ee2ee43ca40ee5e Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95a10665d68ba19a51849d838550b40717cf1561 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9790c9f4c8176a3d036d8a75b092ca19b110b712 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13e766e23762ee3e85d682b741baf9b3eebc47af Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1443c4cd14266f461461b5596725f1b14d992b85 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..056dc726af0e897aab44537fcfa290b72fe27fd0 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70dc3d18ea7363d04e2eb9adeee7efb0cce143ae Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/activation.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..15b4d36e8b44a700596b124bb0b521c25bd3d569 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/activation.py @@ -0,0 +1,344 @@ +# mypy: allow-untyped-defs +from warnings import warn + +import torch + + +__all__ = [ + "ReLU6", + "Hardswish", + "ELU", + "LeakyReLU", + "Sigmoid", + "Softmax", + "MultiheadAttention", + "PReLU", +] + + +class ReLU6(torch.nn.ReLU): + r"""Applies the element-wise function: + + :math:`\text{ReLU6}(x) = \min(\max(x_0, x), q(6))`, where :math:`x_0` is the + zero_point, and :math:`q(6)` is the quantized representation of number 6. + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. image:: ../scripts/activation_images/ReLU6.png + + Examples:: + + >>> m = nn.quantized.ReLU6() + >>> input = torch.randn(2) + >>> # xdoctest: +SKIP + >>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32) + >>> output = m(input) + """ + + def __init__(self, inplace=False): + super().__init__(inplace) + self.inplace = inplace + + def forward(self, input): + return torch.ops.quantized.relu6(input, self.inplace) + + def _get_name(self): + return "QuantizedReLU6" + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + return ReLU6(mod.inplace) + + +class Hardswish(torch.nn.Hardswish): + r"""This is the quantized version of :class:`~torch.nn.Hardswish`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + + def __init__(self, scale, zero_point, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.hardswish(input, self.scale, self.zero_point) + + def _get_name(self): + return "QuantizedHardswish" + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return Hardswish(float(scale), int(zero_point)) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(float(scale), int(zero_point)) + + +class ELU(torch.nn.ELU): + r"""This is the quantized equivalent of :class:`~torch.nn.ELU`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + alpha: the alpha constant + """ + + def __init__(self, scale, zero_point, alpha=1.0): + super().__init__(alpha) + self.scale = scale + self.zero_point = zero_point + + def forward(self, input): + return torch.ao.nn.quantized.functional.elu( + input, self.scale, self.zero_point, self.alpha + ) + + def _get_name(self): + return "QuantizedELU" + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return ELU(float(scale), int(zero_point), mod.alpha) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(float(scale), int(zero_point), mod.alpha) + + +class LeakyReLU(torch.nn.LeakyReLU): + r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + negative_slope: Controls the angle of the negative slope. Default: 1e-2 + """ + + def __init__( + self, + scale: float, + zero_point: int, + negative_slope: float = 1e-2, + inplace: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(negative_slope, inplace) + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.leaky_relu( + input, self.negative_slope, self.inplace, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedLeakyReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace) + + +class Sigmoid(torch.nn.Sigmoid): + r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + + def __init__(self, output_scale: float, output_zero_point: int): + super().__init__() + self.output_scale = output_scale + self.output_zero_point = output_zero_point + + def forward(self, input): + return torch.ops.quantized.sigmoid( + input, self.output_scale, self.output_zero_point + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + ( + output_scale, + output_zero_point, + ) = mod.activation_post_process.calculate_qparams() + return cls(float(output_scale), int(output_zero_point)) + + +class Softmax(torch.nn.Softmax): + r"""This is the quantized version of :class:`~torch.nn.Softmax`. + + Args: + dim: A dimension along which Softmax will be computed (so every slice along dim will sum to 1). + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + + def __init__(self, dim=None, scale=1.0, zero_point=0): + super().__init__() + self.dim = dim + self.scale = scale + self.zero_point = zero_point + + def forward(self, input): + dim = self.dim + if dim is None: + stacklevel = 3 + # Note: adding the mypy ignore on _get_softmax_dim seems less bad + # than making `_get_softmax_dim` an official API. + dim = torch.nn.functional._get_softmax_dim( # type: ignore[attr-defined] + "softmax", input.dim(), stacklevel + ) + return torch.ops.quantized.softmax(input, dim, self.scale, self.zero_point) + + def _get_name(self): + return "QuantizedSoftmax" + + @staticmethod + def from_float(mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return Softmax(mod.dim, float(scale), int(zero_point)) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(mod.dim, float(scale), int(zero_point)) + + +class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention): + _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention + + def _get_name(self): + return "QuantizedMultiheadAttention" + + @classmethod + def from_float(cls, other): + # The whole flow is float -> observed -> quantized + # This class does observed -> quantized only + raise NotImplementedError( + "It looks like you are trying to convert a " + "non-observed MHA module. Please, see " + "the examples on quantizable MHAs." + ) + + @classmethod + def from_observed(cls, other): + converted = torch.ao.quantization.convert( + other, + mapping=None, + inplace=False, + remove_qconfig=True, + convert_custom_config_dict=None, + ) + converted.__class__ = cls + # Remove the parameters for the bias_k and bias_v to quantize them + # TODO: This is a potential source of accuracy drop. + # quantized cat takes the scale and zp of the first + # element, which might lose the precision in the bias_k + # and the bias_v (which are cat'ed with k/v being first). + if converted.bias_k is not None: + bias_k = converted._parameters.pop("bias_k") + sc, zp = torch._choose_qparams_per_tensor(bias_k, reduce_range=False) + bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8) + setattr(converted, "bias_k", bias_k) # noqa: B010 + + if converted.bias_v is not None: + bias_v = converted._parameters.pop("bias_v") + sc, zp = torch._choose_qparams_per_tensor( + bias_k, # type: ignore[possibly-undefined] + reduce_range=False, + ) + bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8) + setattr(converted, "bias_v", bias_v) # noqa: B010 + + del converted.in_proj_weight + del converted.in_proj_bias + + return converted + + +class PReLU(torch.nn.Module): + r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + num_parameters: number of parameters: 1, or the number of channels at input. Default: 1 + """ + + def __init__( + self, output_scale: float, output_zero_point: int, num_parameters: int = 1 + ) -> None: + super().__init__() + self.num_parameters = num_parameters + self.scale = output_scale + self.zero_point = output_zero_point + w = torch.randn(num_parameters, dtype=torch.float) + qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8) + self.set_weight(qw) + + def set_weight(self, w: torch.Tensor) -> None: + self.weight = w + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.prelu( + input, self.weight, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedPReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + qprelu = cls(float(scale), int(zero_point), mod.num_parameters) + float_wt = mod.weight.float() + observer = mod.qconfig.weight() + observer(float_wt) + if observer.dtype != torch.quint8: + warn( + f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}" + ) + wt_scale, wt_zp = observer.calculate_qparams() + qweight = torch.quantize_per_tensor( + float_wt, float(wt_scale), int(wt_zp), torch.quint8 + ) + qprelu.set_weight(qweight) + return qprelu + + @classmethod + def from_reference(cls, mod, scale, zero_point): + qprelu = cls(float(scale), int(zero_point), mod.num_parameters) + float_wt = mod.weight.float() + observer = mod.qconfig.weight() + observer(float_wt) + if observer.dtype != torch.quint8: + warn( + f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}" + ) + wt_scale, wt_zp = observer.calculate_qparams() + qweight = torch.quantize_per_tensor( + float_wt, float(wt_scale), int(wt_zp), torch.quint8 + ) + qprelu.set_weight(qweight) + return qprelu diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/batchnorm.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..069db116a064b5940cbd86429fb0758399c12c78 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/batchnorm.py @@ -0,0 +1,128 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.intrinsic as nni + + +__all__ = ["BatchNorm2d", "BatchNorm3d"] + + +class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm): + def __init__( + self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(num_features, eps, momentum, True, True, **factory_kwargs) + self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs)) + + @staticmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + activation_post_process = mod.activation_post_process + if type(mod) == cls._NNI_BN_RELU_MODULE: + mod = mod[0] + scale, zero_point = activation_post_process.calculate_qparams() + new_mod = cls(mod.num_features, mod.eps) + new_mod.weight = mod.weight + new_mod.bias = mod.bias + new_mod.running_mean = mod.running_mean + new_mod.running_var = mod.running_var + new_mod.scale = scale + new_mod.zero_point = zero_point + return new_mod + + @classmethod + def from_reference(cls, bn, output_scale, output_zero_point): + qbn = cls( + bn.num_features, + bn.eps, + bn.momentum, + device=bn.weight.device, + dtype=bn.weight.dtype, + ) + qbn.weight = bn.weight + qbn.bias = bn.bias + qbn.running_mean = bn.running_mean + qbn.running_var = bn.running_var + qbn.scale = output_scale + qbn.zero_point = output_zero_point + return qbn + + +class BatchNorm2d(_BatchNorm): + r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`.""" + + _NNI_BN_RELU_MODULE = nni.BNReLU2d + + def __init__( + self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(num_features, eps, momentum, **factory_kwargs) + + def _get_name(self): + return "QuantizedBatchNorm2d" + + def _check_input_dim(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # disabling this since this is not symbolically traceable + # self._check_input_dim(input) + return torch.ops.quantized.batch_norm2d( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + self.eps, + self.scale, + self.zero_point, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return _BatchNorm.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class BatchNorm3d(_BatchNorm): + r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.""" + + _NNI_BN_RELU_MODULE = nni.BNReLU3d + + def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(num_features, eps, momentum, **factory_kwargs) + + def _get_name(self): + return "QuantizedBatchNorm3d" + + def _check_input_dim(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, H, W)`!") + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # disabling this since this is not symbolically traceable + # self._check_input_dim(input) + return torch.ops.quantized.batch_norm3d( + input, + self.weight, + self.bias, + self.running_mean, + self.running_var, + self.eps, + self.scale, + self.zero_point, + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + return _BatchNorm.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/conv.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..907a04898273be67e33a3b68d534577b1c353dee --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/conv.py @@ -0,0 +1,1243 @@ +# mypy: allow-untyped-defs +r"""Quantized convolution modules.""" + +from typing import ClassVar, Optional + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.nn as nn +import torch.nn.functional as F +from torch._ops import ops +from torch.nn.common_types import _size_1_t +from torch.nn.modules.utils import _pair, _single, _triple +from torch.nn.utils import fuse_conv_bn_weights + +from .utils import _quantize_weight, WeightedQuantizedModule + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] + +_SUPPORTED_PADDING = {"zeros", "reflect"} + + +def _reverse_repeat_padding(padding: list[int]) -> list[int]: + _reversed_padding_repeated_twice: list[int] = [] + N = len(padding) + for idx in range(N): + _reversed_padding_repeated_twice.extend(padding[N - idx - 1] for _ in range(2)) + return _reversed_padding_repeated_twice + + +class _ConvNd(WeightedQuantizedModule): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + # All subclasses have this signature - See PR #49702s + raise NotImplementedError + + def _init( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode="zeros", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + if in_channels % groups != 0: + raise ValueError("in_channels must be divisible by groups") + if out_channels % groups != 0: + raise ValueError("out_channels must be divisible by groups") + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + if padding_mode not in _SUPPORTED_PADDING: + raise ValueError( + f"'padding_mode' {padding_mode} is not supported by quantized convolution" + ) + self.padding_mode = padding_mode + # Initialize as NCHW. set_weight will internally transpose to NHWC. + if self.transposed: + weight_shape = [in_channels, out_channels // self.groups] + else: + weight_shape = [out_channels, in_channels // self.groups] + qweight = torch._empty_affine_quantized( + weight_shape + list(kernel_size), + scale=1, + zero_point=0, + dtype=torch.qint8, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ) + bias_float = ( + torch.zeros( + out_channels, + dtype=torch.float, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ) + if bias + else None + ) + + self.set_weight_bias(qweight, bias_float) + self.scale = 1.0 + self.zero_point = 0 + + def set_weight_bias(self, qweight, bias_float): + raise NotImplementedError + + def bias(self): + raise NotImplementedError + + def _weight_bias(self): + raise NotImplementedError + + def extra_repr(self): + s = ( + "{in_channels}, {out_channels}, kernel_size={kernel_size}" + ", stride={stride}, scale={scale}, zero_point={zero_point}" + ) + if self.padding != (0,) * len(self.padding): + s += ", padding={padding}" + if self.dilation != (1,) * len(self.dilation): + s += ", dilation={dilation}" + if self.output_padding != (0,) * len(self.output_padding): + s += ", output_padding={output_padding}" + if self.groups != 1: + s += ", groups={groups}" + if self.bias() is None: + s += ", bias=False" + return s.format(**self.__dict__) + + # ===== Serialization methods ===== + # The special consideration here is that we have to unpack the weights into + # their regular QTensor form for serialization. Packed weights should not + # live outside the process in which they were created, rather they should be + # derived from the QTensor weight. + # self + # |--- weight : Tensor + # |--- bias : Tensor + # + # TODO: maybe change to this when https://github.com/pytorch/pytorch/pull/32958 is landed + # self + # |--- _packed_params : Conv2dPackedParamsBase or Conv3dPackedParamsBase + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + (w, b) = self._weight_bias() + destination[prefix + "weight"] = w + destination[prefix + "bias"] = b + destination[prefix + "scale"] = torch.tensor(self.scale) + destination[prefix + "zero_point"] = torch.tensor(self.zero_point) + + @torch.jit.export + def __getstate__(self): + (w, b) = self._weight_bias() + return ( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.transposed, + self.output_padding, + self.groups, + self.padding_mode, + w, + b, + self.scale, + self.zero_point, + self.training, + ) + + # ===== Deserialization methods ===== + # Counterpart to the serialization methods, we must pack the serialized + # QTensor weight into its packed format for use by the FBGEMM ops. + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.set_weight_bias(state_dict[prefix + "weight"], state_dict[prefix + "bias"]) + state_dict.pop(prefix + "weight") + state_dict.pop(prefix + "bias") + self.scale = float(state_dict[prefix + "scale"]) + state_dict.pop(prefix + "scale") + self.zero_point = int(state_dict[prefix + "zero_point"]) + state_dict.pop(prefix + "zero_point") + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def __setstate__(self, state): + self.in_channels = state[0] + self.out_channels = state[1] + self.kernel_size = state[2] + self.stride = state[3] + self.padding = state[4] + self.dilation = state[5] + self.transposed = state[6] + self.output_padding = state[7] + self.groups = state[8] + self.padding_mode = state[9] + self.set_weight_bias(state[10], state[11]) + self.scale = state[12] + self.zero_point = state[13] + self.training = state[14] + + def __deepcopy__(self, memo): + new_instance = type(self).__new__(type(self)) + torch.nn.Module.__init__(new_instance) + state = self.__getstate__() + new_instance.__setstate__(state) + return new_instance + + def __copy__(self): + return self.__deepcopy__({}) + + @classmethod + def get_qconv(cls, mod, activation_post_process, weight_post_process=None): + r"""Creates a qconv object and returns it.""" + if weight_post_process is None: + weight_post_process = mod.qconfig.weight() + weight_post_process(mod.weight) + assert weight_post_process.dtype == torch.qint8, ( + "Weight observer must have a dtype of qint8" + ) + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + # the __init__ call used is the one from derived classes and not the one from _ConvNd + qconv = cls( + mod.in_channels, + mod.out_channels, + mod.kernel_size, + mod.stride, + mod.padding, + mod.dilation, + mod.groups, + mod.bias is not None, + mod.padding_mode, + ) + qconv.set_weight_bias(qweight, mod.bias) + if ( + activation_post_process is None + or activation_post_process.dtype == torch.float + ): + return qconv # dynamic quantization doesn't need scale/zero_point + else: + act_scale, act_zp = activation_post_process.calculate_qparams() + qconv.scale = float(act_scale) + qconv.zero_point = int(act_zp) + return qconv + + @staticmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + if hasattr(mod, "weight_fake_quant"): + # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \ + # ".from_float only works for " + cls.__QAT_MODULE.__name__ + if type(mod) == cls._NNIQAT_CONV_BN_MODULE: + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + assert hasattr(mod, "activation_post_process"), ( + "Input QAT module must have observer attached" + ) + weight_post_process = mod.weight_fake_quant + activation_post_process = mod.activation_post_process + else: + assert type(mod) == cls._FLOAT_MODULE, ( + " nnq." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + + " but got:" + + str(type(mod)) + ) + assert hasattr(mod, "qconfig"), ( + "Input float module must have qconfig defined." + ) + activation_post_process = ( + None + if not hasattr(mod, "activation_post_process") + else mod.activation_post_process + ) + if type(mod) in [ + cls._NNI_CONV_RELU_MODULE, + cls._NNI_CONV_ADD_MODULE, + cls._NNI_CONV_ADD_RELU_MODULE, + ]: + mod = mod[0] + weight_post_process = mod.qconfig.weight() + return cls.get_qconv(mod, activation_post_process, weight_post_process) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module + Args: + ref_qconv (Module): a reference quantized module, either produced by torch.ao.quantization + utilities or provided by the user + output_scale (float): scale for output Tensor + output_zero_point (int): zero point for output Tensor + """ + qconv = cls( + ref_qconv.in_channels, + ref_qconv.out_channels, + ref_qconv.kernel_size, # type: ignore[arg-type] + ref_qconv.stride, # type: ignore[arg-type] + ref_qconv.padding, # type: ignore[arg-type] + ref_qconv.dilation, # type: ignore[arg-type] + ref_qconv.groups, + ref_qconv.bias is not None, # type: ignore[arg-type] + ref_qconv.padding_mode, + device=ref_qconv.weight.device, + dtype=ref_qconv.weight.dtype, + ) + qweight = ref_qconv.get_quantized_weight() + qconv.set_weight_bias(qweight, ref_qconv.bias) + qconv.scale = float(output_scale) + qconv.zero_point = int(output_zero_point) + return qconv + + +class Conv1d(_ConvNd): + r"""Applies a 1D convolution over a quantized input signal composed of + several quantized input planes. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv1d`. + + .. note:: + Only `zeros` is supported for the :attr:`padding_mode` argument. + + .. note:: + Only `torch.quint8` is supported for the input data type. + + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv1d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2) + >>> input = torch.randn(20, 16, 100) + >>> # quantize input to quint8 + >>> # xdoctest: +SKIP + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, + ... dtype=torch.quint8) + >>> output = m(q_input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nniqat.ConvBn1d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU1d + _NNI_CONV_ADD_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = padding if isinstance(padding, str) else _single(padding) + dilation = _single(dilation) + + # Subclasses of _ConvNd needs to call _init rather than __init__. See + # discussion on PR #49702 + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _single(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConv1d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + if self.padding_mode == "zeros": + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups + ) + else: + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, _pair(0), self.dilation, self.groups + ) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv1d_unpack(self._packed_params) + return w, b + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != "zeros": + # Padding in Conv1d is stored as (p, p), need to get (p,) + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1]) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv1d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module or qparams_dict. + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + return _ConvNd.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class Conv2d(_ConvNd): + r"""Applies a 2D convolution over a quantized input signal composed of + several quantized input planes. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv2d`. + + .. note:: + Only `zeros` is supported for the :attr:`padding_mode` argument. + + .. note:: + Only `torch.quint8` is supported for the input data type. + + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv2d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> # With square kernels and equal stride + >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) + >>> input = torch.randn(20, 16, 50, 100) + >>> # quantize input to quint8 + >>> # xdoctest: +SKIP + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nniqat.ConvBn2d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU2d + _NNI_CONV_ADD_MODULE: ClassVar[type[nni.ConvAdd2d]] = nni.ConvAdd2d + _NNI_CONV_ADD_RELU_MODULE: ClassVar[type[nni.ConvAddReLU2d]] = nni.ConvAddReLU2d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _pair(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConv2d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + if self.padding_mode == "zeros": + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups + ) + else: + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, _pair(0), self.dilation, self.groups + ) + + def _weight_bias(self): + return self._packed_params.unpack() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv2d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module or qparams_dict. + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + return _ConvNd.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +class Conv3d(_ConvNd): + r"""Applies a 3D convolution over a quantized input signal composed of + several quantized input planes. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv3d`. + + .. note:: + Only `zeros` is supported for the :attr:`padding_mode` argument. + + .. note:: + Only `torch.quint8` is supported for the input data type. + + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv3d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> # With square kernels and equal stride + >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2)) + >>> input = torch.randn(20, 16, 56, 56, 56) + >>> # quantize input to quint8 + >>> # xdoctest: +SKIP + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + + """ + + _FLOAT_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nniqat.ConvBn3d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU3d + _NNI_CONV_ADD_MODULE: ClassVar[Optional[type[nn.Module]]] = None + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + ): + assert padding_mode != "reflect", "Conv3d does not support reflection padding" + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _triple(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConv3d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + if self.padding_mode == "zeros": + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups + ) + else: + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, _triple(0), self.dilation, self.groups + ) + + def _weight_bias(self): + return self._packed_params.unpack() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != "zeros": + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad( + input, _reversed_padding_repeated_twice, mode=self.padding_mode + ) + return ops.quantized.conv3d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module or qparams_dict. + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + return _ConvNd.from_float( + cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + + +# === Transposed Convolutions === + + +class _ConvTransposeNd(_ConvNd): + _FLOAT_MODULE: ClassVar[type[nn.modules.conv._ConvNd]] + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + device=None, + dtype=None, + ): + if padding_mode != "zeros": + raise ValueError( + f'Only "zeros" padding mode is supported for {self.__class__.__name__}' + ) + factory_kwargs = {"device": device, "dtype": dtype} + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super()._init( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _input_padding( + self, kernel_size: list[int], dilation: list[int], padding: list[int] + ) -> list[int]: + res = torch.jit.annotate(list[int], []) + for kdx in range(len(kernel_size)): + pad = dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx] + res.append(pad) + return res + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] + r"""Creates a quantized module from a float module or qparams_dict. + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + # derived classes override cls._FLOAT_MODULE attribute + msg = ( + " nnq." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined] + ) + assert type(mod) == cls._FLOAT_MODULE, msg + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined." + weight_post_process = mod.qconfig.weight() # type: ignore[operator, union-attr] + weight_post_process(mod.weight) + assert weight_post_process.dtype == torch.qint8, ( + "Weight observer must have a dtype of qint8" + ) + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd + qconv = cls( + mod.in_channels, + mod.out_channels, + mod.kernel_size, # type: ignore[call-arg] + mod.stride, + mod.padding, + mod.output_padding, + mod.groups, + mod.bias is not None, + mod.dilation, + mod.padding_mode, + ) + qconv.set_weight_bias(qweight, mod.bias) + if ( + not hasattr(mod, "activation_post_process") + or mod.activation_post_process.dtype == torch.float + ): + return qconv # dynamic quantization doesn't need scale/zero_point + else: + act_scale, act_zp = mod.activation_post_process.calculate_qparams() # type: ignore[operator, union-attr] + qconv.scale = float(act_scale) + qconv.zero_point = int(act_zp) + return qconv + + @staticmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] + r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module + Args: + ref_qconvt (Module): a reference quantized module, either produced by torch.ao.quantization + utilities or provided by the user + output_scale (float): scale for output Tensor + output_zero_point (int): zero point for output Tensor + """ + qconv = cls( + ref_qconvt.in_channels, + ref_qconvt.out_channels, + ref_qconvt.kernel_size, # type: ignore[arg-type] + ref_qconvt.stride, # type: ignore[arg-type] + ref_qconvt.padding, # type: ignore[arg-type] + ref_qconvt.output_padding, # type: ignore[arg-type] + ref_qconvt.groups, + ref_qconvt.bias is not None, # type: ignore[arg-type] + ref_qconvt.dilation, # type: ignore[arg-type] + ref_qconvt.padding_mode, + device=ref_qconvt.weight.device, + dtype=ref_qconvt.weight.dtype, + ) + qweight = ref_qconvt.get_quantized_weight() + qconv.set_weight_bias(qweight, ref_qconvt.bias) + qconv.scale = float(output_scale) + qconv.zero_point = int(output_zero_point) + return qconv + + +class ConvTranspose1d(_ConvTransposeNd): + r"""Applies a 1D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose1d`. + + .. note:: Currently only the QNNPACK engine is implemented. + Please, set the `torch.backends.quantized.engine = 'qnnpack'` + + For special notes, please, see :class:`~torch.ao.nn.quantized.Conv1d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose2d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> torch.backends.quantized.engine = 'qnnpack' + >>> from torch.ao.nn import quantized as nnq + >>> # With square kernels and equal stride + >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> input = torch.randn(20, 16, 50) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv1d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose1d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6]) + >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose1d]] = nn.ConvTranspose1d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + output_padding = _single(output_padding) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConvTranspose1d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params = torch.ops.quantized.conv_transpose1d_prepack( + w, + b, + self.stride, + self.padding, + self.output_padding, + self.dilation, + self.groups, + ) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv_transpose1d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + return torch.ops.quantized.conv_transpose1d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] + return _ConvTransposeNd.from_reference( + cls, ref_qconvt, output_scale, output_zero_point + ) + + +class ConvTranspose2d(_ConvTransposeNd): + r"""Applies a 2D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose2d`. + + For special notes, please, see :class:`~torch.ao.nn.quantized.Conv2d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose2d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> # QNNPACK or FBGEMM as backend + >>> torch.backends.quantized.engine = 'qnnpack' + >>> # With square kernels and equal stride + >>> import torch.ao.nn.quantized as nnq + >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> input = torch.randn(20, 16, 50, 100) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6, 6]) + >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose2d]] = nn.ConvTranspose2d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConvTranspose2d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params = torch.ops.quantized.conv_transpose2d_prepack( + w, + b, + self.stride, + self.padding, + self.output_padding, + self.dilation, + self.groups, + ) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv2d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + return ops.quantized.conv_transpose2d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] + return _ConvTransposeNd.from_reference( + cls, ref_qconvt, output_scale, output_zero_point + ) + + +class ConvTranspose3d(_ConvTransposeNd): + r"""Applies a 3D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose3d`. + + .. note:: Currently only the FBGEMM engine is implemented. + Please, set the `torch.backends.quantized.engine = 'fbgemm'` + + For special notes, please, see :class:`~torch.ao.nn.quantized.Conv3d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose3d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> torch.backends.quantized.engine = 'fbgemm' + >>> from torch.ao.nn import quantized as nnq + >>> # With cubic kernels and equal stride + >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2) + >>> # non-cubic kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2)) + >>> input = torch.randn(20, 16, 50, 100, 100) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6, 6, 6]) + >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12, 12]) + """ + + _FLOAT_MODULE: ClassVar[type[nn.ConvTranspose3d]] = nn.ConvTranspose3d + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + output_padding = _triple(output_padding) + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) + + def _get_name(self): + return "QuantizedConvTranspose3d" + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params = torch.ops.quantized.conv_transpose3d_prepack( + w, + b, + self.stride, + self.padding, + self.output_padding, + self.dilation, + self.groups, + ) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv3d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, T, H, W)`!") + return ops.quantized.conv_transpose3d( + input, self._packed_params, self.scale, self.zero_point + ) + + @classmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override] + return _ConvTransposeNd.from_reference( + cls, ref_qconvt, output_scale, output_zero_point + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/dropout.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..3744ca30d5a49ba92cbb86690f2683af02d594fe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/dropout.py @@ -0,0 +1,30 @@ +# mypy: allow-untyped-defs +import torch + + +__all__ = ["Dropout"] + + +class Dropout(torch.nn.Dropout): + r"""This is the quantized equivalent of :class:`~torch.nn.Dropout`. + And this is a placeholder to enable models where fp32 tensors + had dropout to work with quantized tensors in train and eval mode. + + Args: + p: probability of an element to be zeroed + inplace: can optionally do the operation in-place. Default: ``False`` + """ + + def forward(self, input): + return input + + def _get_name(self): + return "QuantizedDropout" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return cls(mod.p, mod.inplace) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(mod.p, mod.inplace) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/embedding_ops.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c39c8de8ce2ccc1af105964edbcb11f3926ad21d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/embedding_ops.py @@ -0,0 +1,413 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn as nn +from torch import Tensor # noqa: F401 +from torch._jit_internal import List, Optional # noqa: F401 + +from .utils import _hide_packed_params_repr, _quantize_weight + + +__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"] + + +class EmbeddingPackedParams(torch.nn.Module): + _version = 1 + + def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8): + super().__init__() + self.dtype = dtype + if self.dtype in [torch.quint8, torch.quint4x2]: + scales = torch.ones(num_embeddings, dtype=torch.float) + zero_points = torch.zeros(num_embeddings, dtype=torch.float) + wq = torch._empty_per_channel_affine_quantized( + [num_embeddings, embedding_dim], + scales=scales, + zero_points=zero_points, + axis=0, + dtype=self.dtype, + ) + self.set_weight(wq) + else: + raise NotImplementedError( + f"Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}" + ) + + @torch.jit.export + def set_weight(self, weight: torch.Tensor) -> None: + if self.dtype in [torch.quint8, torch.quint4x2]: + self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight) + else: + raise NotImplementedError( + "Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2." + ) + + @torch.jit.export + def _weight(self): + if self.dtype in [torch.quint8, torch.quint4x2]: + return torch.ops.quantized.embedding_bag_unpack(self._packed_weight) + else: + raise NotImplementedError( + "Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2." + ) + + def forward(self, x): + return x + + # Version 1 + # self + # |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase + # |--- dtype : torch.dtype + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "dtype"] = self.dtype + destination[prefix + "_packed_weight"] = self._weight() + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.dtype = state_dict[prefix + "dtype"] + state_dict.pop(prefix + "dtype") + + weight = state_dict[prefix + "_packed_weight"] + state_dict.pop(prefix + "_packed_weight") + self.set_weight(weight) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def __repr__(self): + return self._weight().__repr__() + + +class Embedding(torch.nn.Module): + r""" + A quantized Embedding module with quantized packed weights as inputs. + We adopt the same interface as `torch.nn.Embedding`, please see + https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html for documentation. + + Similar to :class:`~torch.nn.Embedding`, attributes will be randomly + initialized at module creation time and will be overwritten later + + Attributes: + weight (Tensor): the non-learnable quantized weights of the module of + shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`. + + Examples:: + >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12) + >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8]) + >>> output = m(indices) + >>> print(output.size()) + torch.Size([9, 12]) + + """ + + _version = 1 + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + dtype=torch.quint8, + ) -> None: + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.dtype = dtype + + if _weight is None: + scales = torch.ones(num_embeddings, dtype=torch.float) + zero_points = torch.zeros(num_embeddings, dtype=torch.float) + qweight = torch._empty_per_channel_affine_quantized( + [num_embeddings, embedding_dim], + scales=scales, + zero_points=zero_points, + axis=0, + dtype=torch.quint8, + ) + else: + assert list(_weight.shape) == [ + num_embeddings, + embedding_dim, + ], "Shape of weight does not match num_embeddings and embedding_dim" + qweight = _weight + + self._packed_params = EmbeddingPackedParams( + num_embeddings, embedding_dim, dtype + ) + self._packed_params.set_weight(qweight) + + def forward(self, indices: Tensor) -> Tensor: + if self.dtype == torch.quint4x2: + return torch.ops.quantized.embedding_4bit( + self._packed_params._packed_weight, indices + ) + else: + return torch.ops.quantized.embedding_byte( + self._packed_params._packed_weight, indices + ) + + def _get_name(self): + return "QuantizedEmbedding" + + def __repr__(self): + return _hide_packed_params_repr(self, EmbeddingPackedParams) + + def extra_repr(self): + extra_repr_str = ( + f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, " + f"dtype={self._packed_params.dtype}, qscheme={self.weight().qscheme()}" + ) + + return extra_repr_str + + def set_weight(self, w: torch.Tensor) -> None: + self._packed_params.set_weight(w) + + def weight(self): + return self._packed_params._weight() + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized embedding module from a float module + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by user + """ + if hasattr(mod, "weight_fake_quant"): + assert type(mod) == torch.ao.nn.qat.Embedding, ( + "nnq." + + cls.__name__ + + ".from_float " + + "with fake quant only works for " + + torch.ao.nn.qat.Embedding.__name__ + ) + weight_observer = mod.weight_fake_quant + else: + assert type(mod) == nn.Embedding, ( + "nnq." + + cls.__name__ + + ".from_float only works for " + + nn.Embedding.__name__ + ) + assert hasattr(mod, "qconfig"), ( + "Embedding input float module must have qconfig defined" + ) + from torch.ao.quantization import float_qparams_weight_only_qconfig + + if mod.qconfig is not None and mod.qconfig.weight is not None: # type: ignore[union-attr] + weight_observer = mod.qconfig.weight() # type: ignore[union-attr, operator] + else: + weight_observer = float_qparams_weight_only_qconfig.weight() + + dtype = weight_observer.dtype + is_float_qparams_qconfig = ( + weight_observer.qscheme == torch.per_channel_affine_float_qparams + ) + assert is_float_qparams_qconfig, ( + "Embedding quantization is only supported with float_qparams_weight_only_qconfig." + ) + + assert dtype == torch.quint8 or dtype == torch.quint4x2, ( + f"The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}" + ) + + # Run the observer to calculate qparams. + weight_observer(mod.weight) + qweight = _quantize_weight(mod.weight.float(), weight_observer) + + # Create quantized Embedding module and pass in the quantized weight + qembedding = Embedding(mod.num_embeddings, mod.embedding_dim) + qembedding.set_weight(qweight) + return qembedding + + @classmethod + def from_reference(cls, ref_embedding): + qembedding = cls( + ref_embedding.num_embeddings, + ref_embedding.embedding_dim, + ref_embedding.padding_idx, + ref_embedding.max_norm, + ref_embedding.norm_type, + ref_embedding.scale_grad_by_freq, + ref_embedding.sparse, + ref_embedding.get_quantized_weight(), + ref_embedding.weight_dtype, + ) + return qembedding + + +class EmbeddingBag(Embedding): + r""" + A quantized EmbeddingBag module with quantized packed weights as inputs. + We adopt the same interface as `torch.nn.EmbeddingBag`, please see + https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html for documentation. + + Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly + initialized at module creation time and will be overwritten later + + Attributes: + weight (Tensor): the non-learnable quantized weights of the module of + shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`. + + Examples:: + >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum') + >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) + >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32]) + >>> output = m(indices, offsets) + >>> print(output.size()) + torch.Size([5, 12]) + + """ + + _version = 1 + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "sum", + sparse: bool = False, + _weight: Optional[Tensor] = None, + include_last_offset: bool = False, + dtype=torch.quint8, + ) -> None: + super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype) + + self.mode = mode + self.pruned_weights = False + self.include_last_offset = include_last_offset + self.dtype = dtype + + def forward( + self, + indices: Tensor, + offsets: Optional[Tensor] = None, + per_sample_weights: Optional[Tensor] = None, + compressed_indices_mapping: Optional[Tensor] = None, + ) -> Tensor: + if self.dtype == torch.quint4x2: + return torch.ops.quantized.embedding_bag_4bit( + self._packed_params._packed_weight, + indices, + offsets, + False, + 0, + self.pruned_weights, + per_sample_weights, + compressed_indices_mapping, + self.include_last_offset, + ) + else: + return torch.ops.quantized.embedding_bag_byte( + self._packed_params._packed_weight, + indices, + offsets, + False, + 0, + self.pruned_weights, + per_sample_weights, + compressed_indices_mapping, + self.include_last_offset, + ) + + def _get_name(self): + return "QuantizedEmbeddingBag" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized embedding_bag module from a float module + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by user + """ + if hasattr(mod, "weight_fake_quant"): + weight_observer = mod.weight_fake_quant + else: + assert type(mod) == nn.EmbeddingBag, ( + "nnq." + + cls.__name__ + + ".from_float only works for " + + nn.EmbeddingBag.__name__ + ) + assert hasattr(mod, "qconfig"), ( + "EmbeddingBag input float module must have qconfig defined" + ) + from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig + + if mod.qconfig is not None and mod.qconfig.weight is not None: # type: ignore[union-attr] + weight_observer = mod.qconfig.weight() # type: ignore[union-attr, operator] + else: + weight_observer = float_qparams_weight_only_qconfig.weight() + + dtype = weight_observer.dtype + is_float_qparams_qconfig = ( + weight_observer.qscheme == torch.per_channel_affine_float_qparams + ) + assert is_float_qparams_qconfig, ( + "EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig." + ) + + assert dtype == torch.quint8 or dtype == torch.quint4x2, ( + f"The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2, got {dtype}" + ) + + # Run the observer to calculate qparams. + weight_observer(mod.weight) + qweight = _quantize_weight(mod.weight.float(), weight_observer) + + # Create quantized EmbeddingBag module and pass in the quantized weight + qembedding_bag = EmbeddingBag( + mod.num_embeddings, + mod.embedding_dim, + max_norm=mod.max_norm, + norm_type=mod.norm_type, + scale_grad_by_freq=mod.scale_grad_by_freq, + mode=mod.mode, + sparse=mod.sparse, + include_last_offset=mod.include_last_offset, + dtype=dtype, + ) + qembedding_bag.set_weight(qweight) + return qembedding_bag + + @classmethod + def from_reference(cls, ref_embedding_bag): + qembedding_bag = cls( + ref_embedding_bag.num_embeddings, + ref_embedding_bag.embedding_dim, + ref_embedding_bag.max_norm, + ref_embedding_bag.norm_type, + ref_embedding_bag.scale_grad_by_freq, + ref_embedding_bag.mode, + ref_embedding_bag.sparse, + ref_embedding_bag.get_quantized_weight(), + ref_embedding_bag.include_last_offset, + ref_embedding_bag.weight_dtype, + ) + return qembedding_bag diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/functional_modules.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/functional_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..3b364b43f606071ad6bf3d20ae2b94e0a391829e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/functional_modules.py @@ -0,0 +1,298 @@ +# mypy: allow-untyped-defs + +import torch +from torch import Tensor +from torch._ops import ops + + +__all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"] + + +class FloatFunctional(torch.nn.Module): + r"""State collector class for float operations. + + The instance of this class can be used instead of the ``torch.`` prefix for + some operations. See example usage below. + + .. note:: + + This class does not provide a ``forward`` hook. Instead, you must use + one of the underlying functions (e.g. ``add``). + + Examples:: + + >>> f_add = FloatFunctional() + >>> a = torch.tensor(3.0) + >>> b = torch.tensor(4.0) + >>> f_add.add(a, b) # Equivalent to ``torch.add(a, b)`` + + Valid operation names: + - add + - cat + - mul + - add_relu + - add_scalar + - mul_scalar + """ + + def __init__(self) -> None: + super().__init__() + self.activation_post_process = torch.nn.Identity() + + def forward(self, x): + raise RuntimeError( + "FloatFunctional is not intended to use the " + + "'forward'. Please use the underlying operation" + ) + + r"""Operation equivalent to ``torch.add(Tensor, Tensor)``""" + + def add(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.add(Tensor, float)``""" + + def add_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.add(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``""" + + def mul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.mul(x, y) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, float)``""" + + def mul_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.mul(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.cat``""" + + def cat(self, x: list[Tensor], dim: int = 0) -> Tensor: + r = torch.cat(x, dim=dim) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``relu(torch.add(x,y))``""" + + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + r = torch.nn.functional.relu(r) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``""" + + def matmul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.matmul(x, y) + r = self.activation_post_process(r) + return r + + +class FXFloatFunctional(torch.nn.Module): + r"""module to replace FloatFunctional module before FX graph mode quantization, + since activation_post_process will be inserted in top level module directly + + Valid operation names: + - add + - cat + - mul + - add_relu + - add_scalar + - mul_scalar + """ + + def forward(self, x): + raise RuntimeError( + "FloatFunctional is not intended to use the " + + "'forward'. Please use the underlying operation" + ) + + r"""Operation equivalent to ``torch.add(Tensor, Tensor)``""" + + def add(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + return r + + r"""Operation equivalent to ``torch.add(Tensor, float)``""" + + def add_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.add(x, y) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``""" + + def mul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.mul(x, y) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, float)``""" + + def mul_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.mul(x, y) + return r + + r"""Operation equivalent to ``torch.cat``""" + + def cat(self, x: list[Tensor], dim: int = 0) -> Tensor: + r = torch.cat(x, dim=dim) + return r + + r"""Operation equivalent to ``relu(torch.add(x,y))``""" + + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + r = torch.nn.functional.relu(r) + return r + + r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``""" + + def matmul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.matmul(x, y) + return r + + +class QFunctional(torch.nn.Module): + r"""Wrapper class for quantized operations. + + The instance of this class can be used instead of the + ``torch.ops.quantized`` prefix. See example usage below. + + .. note:: + + This class does not provide a ``forward`` hook. Instead, you must use + one of the underlying functions (e.g. ``add``). + + Examples:: + + >>> q_add = QFunctional() + >>> # xdoctest: +SKIP + >>> a = torch.quantize_per_tensor(torch.tensor(3.0), 1.0, 0, torch.qint32) + >>> b = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint32) + >>> q_add.add(a, b) # Equivalent to ``torch.ops.quantized.add(a, b, 1.0, 0)`` + + Valid operation names: + - add + - cat + - mul + - add_relu + - add_scalar + - mul_scalar + """ + + def __init__(self) -> None: + super().__init__() + self.scale = 1.0 + self.zero_point = 0 + self.activation_post_process = torch.nn.Identity() + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "scale"] = torch.tensor(self.scale) + destination[prefix + "zero_point"] = torch.tensor(self.zero_point) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.scale = float(state_dict.pop(prefix + "scale")) + self.zero_point = int(state_dict.pop(prefix + "zero_point")) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _get_name(self): + return "QFunctional" + + def extra_repr(self): + return f"scale={self.scale}, zero_point={self.zero_point}" + + def forward(self, x): + raise RuntimeError( + "Functional is not intended to use the " + + "'forward'. Please use the underlying operation" + ) + + r"""Operation equivalent to ``torch.ops.quantized.add``""" + + def add(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``""" + + def add_scalar(self, x: Tensor, y: float) -> Tensor: + r = ops.quantized.add_scalar(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``""" + + def mul(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``""" + + def mul_scalar(self, x: Tensor, y: float) -> Tensor: + r = ops.quantized.mul_scalar(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.ops.quantized.cat``""" + + def cat(self, x: list[Tensor], dim: int = 0) -> Tensor: + r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.add_relu``""" + + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.matmul(Tensor, Tensor)``""" + + def matmul(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.matmul(x, y, scale=self.scale, zero_point=self.zero_point) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + assert type(mod) == FloatFunctional, ( + "QFunctional.from_float expects an instance of FloatFunctional" + ) + scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator] + new_mod = QFunctional() + new_mod.scale = float(scale) + new_mod.zero_point = int(zero_point) + return new_mod diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/linear.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..9042833f5e30b2ef8cc779345ae6ab542f78c051 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/linear.py @@ -0,0 +1,363 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.nn as nn +from torch.nn.utils.fusion import fuse_linear_bn_weights +from torch.nn.utils.parametrize import type_before_parametrizations + +from .utils import _hide_packed_params_repr, _quantize_weight, WeightedQuantizedModule + + +__all__ = ["LinearPackedParams", "Linear"] + + +class LinearPackedParams(torch.nn.Module): + _version = 3 + + def __init__(self, dtype=torch.qint8): + super().__init__() + self.dtype = dtype + if self.dtype == torch.qint8: + wq = torch._empty_affine_quantized( + [1, 1], scale=1.0, zero_point=0, dtype=torch.qint8 + ) + elif self.dtype == torch.float16: + wq = torch.zeros([1, 1], dtype=torch.float) + self.set_weight_bias(wq, None) # type: ignore[possibly-undefined] + + @torch.jit.export + def set_weight_bias( + self, weight: torch.Tensor, bias: Optional[torch.Tensor] + ) -> None: + if self.dtype == torch.qint8: + self._packed_params = torch.ops.quantized.linear_prepack(weight, bias) + elif self.dtype == torch.float16: + self._packed_params = torch.ops.quantized.linear_prepack_fp16(weight, bias) + else: + raise RuntimeError("Unsupported dtype on dynamic quantized linear!") + + @torch.jit.export + def _weight_bias(self): + if self.dtype == torch.qint8: + return torch.ops.quantized.linear_unpack(self._packed_params) + elif self.dtype == torch.float16: + return torch.ops.quantized.linear_unpack_fp16(self._packed_params) + else: + raise RuntimeError("Unsupported dtype on dynamic quantized linear!") + + def forward(self, x): + return x + + # Version 1 + # self + # |--- weight : Tensor + # |--- bias : Tensor + # + # Version 2 + # self + # |--- weight : Tensor + # |--- bias : Tensor + # |--- dtype : torch.dtype + # + # Version 3 + # self + # |--- _packed_params : (Tensor, Tensor) representing (weight, bias) + # of LinearPackedParams + # |--- dtype : torch.dtype + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "dtype"] = self.dtype + destination[prefix + "_packed_params"] = self._weight_bias() + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + if version is None or version < 2: + self.dtype = torch.qint8 + else: + self.dtype = state_dict[prefix + "dtype"] + state_dict.pop(prefix + "dtype") + + if version is None or version < 3: + self.set_weight_bias( + state_dict[prefix + "weight"], state_dict[prefix + "bias"] + ) + state_dict.pop(prefix + "weight") + state_dict.pop(prefix + "bias") + + if version == 3: + weight, bias = state_dict[prefix + "_packed_params"] + state_dict.pop(prefix + "_packed_params") + self.set_weight_bias(weight, bias) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def __repr__(self): + return self._weight_bias().__repr__() + + +class Linear(WeightedQuantizedModule): + r""" + A quantized linear module with quantized tensor as inputs and outputs. + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation. + + Similar to :class:`~torch.nn.Linear`, attributes will be randomly + initialized at module creation time and will be overwritten later + + Attributes: + weight (Tensor): the non-learnable quantized weights of the module of + shape :math:`(\text{out\_features}, \text{in\_features})`. + bias (Tensor): the non-learnable bias of the module of shape :math:`(\text{out\_features})`. + If :attr:`bias` is ``True``, the values are initialized to zero. + scale: `scale` parameter of output Quantized Tensor, type: double + zero_point: `zero_point` parameter for output Quantized Tensor, type: long + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> m = nn.quantized.Linear(20, 30) + >>> input = torch.randn(128, 20) + >>> # xdoctest: +SKIP + >>> input = torch.quantize_per_tensor(input, 1.0, 0, torch.quint8) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + _version = 3 + _FLOAT_MODULE = (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear) + + def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8): + super().__init__() + # We don't muck around with buffers or attributes or anything here + # to keep the module simple. *everything* is simply a Python attribute. + # Serialization logic is explicitly handled in the below serialization and + # deserialization modules + self.in_features = in_features + self.out_features = out_features + bias = None + if bias_: + bias = torch.zeros(out_features, dtype=torch.float) + + if dtype == torch.qint8: + qweight = torch._empty_affine_quantized( + [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8 + ) + elif dtype == torch.float16: + qweight = torch.zeros([out_features, in_features], dtype=torch.float) + else: + raise RuntimeError("Unsupported dtype specified for quantized Linear!") + + self._packed_params = LinearPackedParams(dtype) + self._packed_params.set_weight_bias(qweight, bias) + self.scale = 1.0 + self.zero_point = 0 + + def _get_name(self): + return "QuantizedLinear" + + def extra_repr(self): + return ( + f"in_features={self.in_features}, out_features={self.out_features}, scale={self.scale}, " + f"zero_point={self.zero_point}, qscheme={self.weight().qscheme()}" + ) + + def __repr__(self): + return _hide_packed_params_repr(self, LinearPackedParams) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear( + x, self._packed_params._packed_params, self.scale, self.zero_point + ) + + # ===== Serialization methods ===== + # The special consideration here is that we have to unpack the weights into their + # regular QTensor form for serialization. Packed weights should not live + # outside the process in which they were created, rather they should be derived + # from the QTensor weight. + # + # Version 1 + # self + # |--- scale : float + # |--- zero_point : int + # |--- weight : Tensor + # |--- bias : Tensor + # + # Version 2 + # self + # |--- scale : float + # |--- zero_point : int + # |--- _packed_params : Module + # |--- weight : Tensor + # |--- bias : Tensor + # + # Version 3 + # self + # |--- scale : float + # |--- zero_point : int + # |--- _packed_params : Module + # |--- _packed_params : (Tensor, Tensor) representing weight, bias + # of LinearPackedParams C++ struct + # + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "scale"] = torch.tensor(self.scale) + destination[prefix + "zero_point"] = torch.tensor(self.zero_point) + + # ===== Deserialization methods ===== + # Counterpart to the serialization methods, we must pack the serialized QTensor + # weight into its packed format for use by the FBGEMM ops. + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.scale = float(state_dict[prefix + "scale"]) + state_dict.pop(prefix + "scale") + + self.zero_point = int(state_dict[prefix + "zero_point"]) + state_dict.pop(prefix + "zero_point") + + version = local_metadata.get("version", None) + + if version is None or version == 1: + # We moved the parameters into a LinearPackedParameters submodule + weight = state_dict.pop(prefix + "weight") + bias = state_dict.pop(prefix + "bias") + state_dict.update( + { + prefix + "_packed_params.weight": weight, + prefix + "_packed_params.bias": bias, + } + ) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + # Function rather than property to make sure that JIT serialization doesn't + # register this as an attribute + def _weight_bias(self): + return self._packed_params._weight_bias() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params.set_weight_bias(w, b) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized module from an observed float module + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + use_precomputed_fake_quant (bool): if True, the module will reuse min/max + values from the precomputed fake quant module. + """ + if hasattr(mod, "weight_fake_quant"): + if type_before_parametrizations(mod) == nniqat.LinearBn1d: + mod.weight, mod.bias = fuse_linear_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + weight_post_process = mod.weight_fake_quant + activation_post_process = mod.activation_post_process + else: + # This function does not participate in JIT, so it is OK to ignore + # the type mismatch in assignment. Also, mypy has an issue with + # iterables not being implemented, so we are ignoring those too. + if not isinstance(cls._FLOAT_MODULE, Iterable): + cls._FLOAT_MODULE = [cls._FLOAT_MODULE] + supported_modules = ", ".join( + [float_mod.__name__ for float_mod in cls._FLOAT_MODULE] + ) + error_msg = f"nnq.{cls.__name__}.from_float only works for {supported_modules}, but got: {type(mod)}" + assert type_before_parametrizations(mod) in cls._FLOAT_MODULE, ( + error_msg.format() + ) + assert hasattr(mod, "qconfig"), ( + "Input float module must have qconfig defined" + ) + activation_post_process = mod.activation_post_process + if type_before_parametrizations(mod) == nni.LinearReLU: + mod = mod[0] + weight_post_process = ( + mod.qconfig.weight() + if not hasattr(mod, "weight_fake_quant") + else mod.weight_fake_quant + ) + + if not use_precomputed_fake_quant: + # Observer may not have been called yet + # Observer might have been called in the previous stage via PTQ algorithm e.g. AdaRound + weight_post_process(mod.weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + qlinear = cls(mod.in_features, mod.out_features, dtype=dtype) + qlinear.set_weight_bias(qweight, mod.bias) + qlinear.scale = float(act_scale) + qlinear.zero_point = int(act_zp) + return qlinear + + @classmethod + def from_reference(cls, ref_qlinear, output_scale, output_zero_point): + r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module + + Args: + ref_qlinear (Module): a reference quantized linear module, either produced by torch.ao.quantization + utilities or provided by the user + output_scale (float): scale for output Tensor + output_zero_point (int): zero point for output Tensor + """ + qlinear = cls(ref_qlinear.in_features, ref_qlinear.out_features) + qweight = ref_qlinear.get_quantized_weight() + qlinear.set_weight_bias(qweight, ref_qlinear.bias) + + qlinear.scale = float(output_scale) + qlinear.zero_point = int(output_zero_point) + return qlinear diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/normalization.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..4db2ac6e928f47236eeab43e63d399452112a263 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/normalization.py @@ -0,0 +1,347 @@ +# mypy: allow-untyped-defs +import torch + + +__all__ = [ + "LayerNorm", + "GroupNorm", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", +] + + +class LayerNorm(torch.nn.LayerNorm): + r"""This is the quantized version of :class:`~torch.nn.LayerNorm`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. + + """ + + def __init__( + self, + normalized_shape, + weight, + bias, + scale, + zero_point, + eps=1e-5, + elementwise_affine=True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + **factory_kwargs, + ) + self.weight = weight + self.bias = bias + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.layer_norm( + input, + self.normalized_shape, + weight=self.weight, + bias=self.bias, + eps=self.eps, + output_scale=self.scale, + output_zero_point=self.zero_point, + ) + + def _get_name(self): + return "QuantizedLayerNorm" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.normalized_shape, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.elementwise_affine, + ) + return new_mod + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls( + mod.normalized_shape, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.elementwise_affine, + ) + + +class GroupNorm(torch.nn.GroupNorm): + r"""This is the quantized version of :class:`~torch.nn.GroupNorm`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. + + """ + + __constants__ = ["num_groups", "num_channels", "eps", "affine"] + + def __init__( + self, + num_groups, + num_channels, + weight, + bias, + scale, + zero_point, + eps=1e-5, + affine=True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs) + self.weight = weight + self.bias = bias + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.group_norm( + input, + self.num_groups, + self.weight, + self.bias, + self.eps, + self.scale, + self.zero_point, + ) + + def _get_name(self): + return "QuantizedGroupNorm" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.num_groups, + mod.num_channels, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + return new_mod + + +class InstanceNorm1d(torch.nn.InstanceNorm1d): + r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. + + """ + + def __init__( + self, + num_features, + weight, + bias, + scale, + zero_point, + eps=1e-5, + momentum=0.1, + affine=False, + track_running_stats=False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.weight = weight + self.bias = bias + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.instance_norm( + input, self.weight, self.bias, self.eps, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedInstanceNorm1d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + return new_mod + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + + +class InstanceNorm2d(torch.nn.InstanceNorm2d): + r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. + + """ + + def __init__( + self, + num_features, + weight, + bias, + scale, + zero_point, + eps=1e-5, + momentum=0.1, + affine=False, + track_running_stats=False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.weight = weight + self.bias = bias + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.instance_norm( + input, self.weight, self.bias, self.eps, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedInstanceNorm2d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + return new_mod + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + + +class InstanceNorm3d(torch.nn.InstanceNorm3d): + r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`. + + Additional args: + * **scale** - quantization scale of the output, type: double. + * **zero_point** - quantization zero point of the output, type: long. + + """ + + def __init__( + self, + num_features, + weight, + bias, + scale, + zero_point, + eps=1e-5, + momentum=0.1, + affine=False, + track_running_stats=False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) + self.weight = weight + self.bias = bias + self.register_buffer("scale", torch.tensor(scale, **factory_kwargs)) + self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.instance_norm( + input, self.weight, self.bias, self.eps, self.scale, self.zero_point + ) + + def _get_name(self): + return "QuantizedInstanceNorm3d" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + scale, zero_point = mod.activation_post_process.calculate_qparams() + new_mod = cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) + return new_mod + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls( + mod.num_features, + mod.weight, + mod.bias, + float(scale), + int(zero_point), + mod.eps, + mod.affine, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/rnn.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5076c9225d2eb00f0b60bd648b6f72833f7ee8e2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/rnn.py @@ -0,0 +1,58 @@ +# mypy: allow-untyped-defs +import torch + + +__all__ = [ + "LSTM", +] + + +class LSTM(torch.ao.nn.quantizable.LSTM): + r"""A quantized long short-term memory (LSTM). + + For the description and the argument types, please, refer to :class:`~torch.nn.LSTM` + + Attributes: + layers : instances of the `_LSTMLayer` + + .. note:: + To access the weights and biases, you need to access them per layer. + See examples in :class:`~torch.ao.nn.quantizable.LSTM` + + Examples:: + >>> # xdoctest: +SKIP + >>> custom_module_config = { + ... 'float_to_observed_custom_module_class': { + ... nn.LSTM: nn.quantizable.LSTM, + ... }, + ... 'observed_to_quantized_custom_module_class': { + ... nn.quantizable.LSTM: nn.quantized.LSTM, + ... } + ... } + >>> tq.prepare(model, prepare_custom_module_class=custom_module_config) + >>> tq.convert(model, convert_custom_module_class=custom_module_config) + """ + + _FLOAT_MODULE = torch.ao.nn.quantizable.LSTM # type: ignore[assignment] + + def _get_name(self): + return "QuantizedLSTM" + + @classmethod + def from_float(cls, *args, **kwargs): + # The whole flow is float -> observed -> quantized + # This class does observed -> quantized only + raise NotImplementedError( + "It looks like you are trying to convert a " + "non-observed LSTM module. Please, see " + "the examples on quantizable LSTMs." + ) + + @classmethod + def from_observed(cls, other): + assert isinstance(other, cls._FLOAT_MODULE) # type: ignore[has-type] + converted = torch.ao.quantization.convert( + other, inplace=False, remove_qconfig=True + ) + converted.__class__ = cls + return converted diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/utils.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..be59d496b8d07a3861b4420e25946e75e6eb0db7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/modules/utils.py @@ -0,0 +1,144 @@ +# mypy: allow-untyped-defs +import abc +import collections +import itertools + +import torch +from torch.nn.modules.module import _addindent + + +__all__ = [ + "WeightedQuantizedModule", +] + + +class WeightedQuantizedModule(torch.nn.Module, metaclass=abc.ABCMeta): + """Wrapper for quantized modules than can be lowered from reference modules.""" + + @classmethod + @abc.abstractmethod + def from_reference(cls, ref_module, output_scale, output_zero_point): + raise NotImplementedError + + +def _get_weight_observer(observer): + # FakeQuantize observer + if hasattr(observer, "activation_post_process"): + observer = observer.activation_post_process + # UniformQuantizationObserverBase observer + return observer + + +def _needs_weight_clamping(observer, dtype): + observer = _get_weight_observer(observer) + if dtype in [torch.qint8, torch.quint8, torch.qint32]: + info = torch.iinfo(dtype) + return observer.quant_min > info.min or observer.quant_max < info.max + return False + + +def _clamp_weights(qweight, observer, scale, zp): + if not _needs_weight_clamping(observer, qweight.dtype): + return qweight + + observer = _get_weight_observer(observer) + min_, max_ = observer.quant_min, observer.quant_max + + # Doing this because can't use torch.ops.quantized.clamp() with per_channel qscheme yet. + qw_int_max = torch.clone(qweight.int_repr()).fill_(max_) + qw_int_min = torch.clone(qweight.int_repr()).fill_(min_) + qw_int = torch.minimum(torch.maximum(qweight.int_repr(), qw_int_min), qw_int_max) + + if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]: + qweight = torch._make_per_tensor_quantized_tensor( + qw_int, scale.item(), zp.item() + ) + elif observer.qscheme in [ + torch.per_channel_symmetric, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + qweight = torch._make_per_channel_quantized_tensor( + qw_int, scale, zp, axis=observer.ch_axis + ) + else: + raise ValueError("Unexpected qscheme " + observer.qscheme) + return qweight + + +def _quantize_weight(float_wt, observer): + wt_scale, wt_zp = observer.calculate_qparams() + if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]: + qweight = torch.quantize_per_tensor( + float_wt, float(wt_scale), int(wt_zp), torch.qint8 + ) + qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp) + elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]: + wt_axis = observer.ch_axis + qweight = torch.quantize_per_channel( + float_wt, + wt_scale.to(torch.double), + wt_zp.to(torch.int64), + wt_axis, + torch.qint8, + ) + qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp) + elif observer.qscheme in [torch.per_channel_affine_float_qparams]: + qweight = torch.quantize_per_channel( + float_wt, + wt_scale.to(torch.float), + wt_zp.to(torch.float), + observer.ch_axis, + observer.dtype, + ) + qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp) + else: + raise ValueError("Unexpected qscheme " + observer.qscheme) + return qweight + + +def _ntuple_from_first(n): + """Converts the argument to a tuple of size n + with the first element repeated.""" + + def parse(x): + while isinstance(x, collections.abc.Sequence): + if len(x) == n: + break + x = x[0] + return tuple(itertools.repeat(x, n)) + + return parse + + +def _hide_packed_params_repr(self, params): + # We don't want to show `PackedParams` children, hence custom + # `__repr__`. This is the same as nn.Module.__repr__, except the check + # for the `params module`. + extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [] + for key, module in self._modules.items(): + if isinstance(module, params): + continue + mod_str = repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append("(" + key + "): " + mod_str) + lines = extra_lines + child_lines + + main_str = self._get_name() + "(" + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + + main_str += ")" + return main_str + + +_pair_from_first = _ntuple_from_first(2) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1e15e9c1516d30f7ca9ee47b21b267533de75b6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/__init__.py @@ -0,0 +1,19 @@ +from .modules import * # noqa: F403 + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "RNNCell", + "LSTMCell", + "GRUCell", + "LSTM", + "GRU", + "Embedding", + "EmbeddingBag", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f9b95c499b5a842e2cd4b3581786efe5b1d661a Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe97c22f5a46a5eafc1432075fc57dd44c3aa8d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py @@ -0,0 +1,29 @@ +from .conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, +) +from .linear import Linear +from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNNCell +from .sparse import Embedding, EmbeddingBag + + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "RNNCell", + "LSTMCell", + "GRUCell", + "LSTM", + "GRU", + "Embedding", + "EmbeddingBag", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e5f717d8d9b34db2819e951394250667347a31b Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/conv.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6eb8ae560be3854273dad6dc89dcfdc549323e47 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/conv.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..633c8920e2bcafc59bfc04d639865fac878be577 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bf51eb216a764fa946185c23db958d334db3dcf Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21afdb7342b6f4688982072cee7a2b1ab4a23360 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba0f57be4b90199b2e8225b038891f7a0617d8b5 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/conv.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..3d4def5c4b7a0ddada6d355efdb016b60970a959 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/conv.py @@ -0,0 +1,511 @@ +# mypy: allow-untyped-defs +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.common_types import _size_1_t + +from .utils import ReferenceQuantizedModule + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] + + +class _ConvNd(torch.nn.modules.conv._ConvNd, ReferenceQuantizedModule): + """A reference version of nn.quantized.Conv2d + we will not pack the parameters in this module, since weight packing is an + optimization for quantized backends supported in PyTorch (fbgemm/qnnpack), + this is useful when user want to use this module in other backends like Glow. + """ + + __annotations__ = {"bias": Optional[torch.Tensor]} + _IS_REFERENCE = True + + @staticmethod + def from_float(cls, float_conv, weight_qparams): + qref_conv = cls( + float_conv.in_channels, + float_conv.out_channels, + float_conv.kernel_size, # type: ignore[arg-type] + float_conv.stride, # type: ignore[arg-type] + float_conv.padding, # type: ignore[arg-type] + float_conv.dilation, # type: ignore[arg-type] + float_conv.groups, + float_conv.bias is not None, # type: ignore[arg-type] + float_conv.padding_mode, + device=float_conv.weight.device, + dtype=float_conv.weight.dtype, + weight_qparams=weight_qparams, + ) + qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach()) + if float_conv.bias is not None: + qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach()) + return qref_conv + + +class Conv1d(_ConvNd, nn.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ): + nn.Conv1d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv1d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv1d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv1d + """ + weight_quant_dequant = self.get_weight() + result = F.conv1d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return result + + def _get_name(self): + return "QuantizedConv1d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvNd.from_float(cls, float_conv, weight_qparams) + + +class Conv2d(_ConvNd, nn.Conv2d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ): + nn.Conv2d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv2d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv2d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv2d + """ + weight_quant_dequant = self.get_weight() + result = F.conv2d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return result + + def _get_name(self): + return "QuantizedConv2d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvNd.from_float(cls, float_conv, weight_qparams) + + +class Conv3d(_ConvNd, nn.Conv3d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ): + nn.Conv3d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.conv3d --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.conv3d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv3d + """ + weight_quant_dequant = self.get_weight() + result = F.conv3d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return result + + def _get_name(self): + return "QuantizedConv3d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvNd.from_float(cls, float_conv, weight_qparams) + + +class _ConvTransposeNd(_ConvNd, torch.nn.modules.conv._ConvTransposeNd): + """A reference version of nn.quantized.ConvTranspose2d + we will not pack the parameters in this module, since weight packing is an + optimization for quantized backends supported in PyTorch (fbgemm/qnnpack), + this is useful when user want to use this module in other backends like Glow. + """ + + @staticmethod + def from_float(cls, float_conv, weight_qparams): + qref_conv = cls( + float_conv.in_channels, + float_conv.out_channels, + float_conv.kernel_size, # type: ignore[arg-type] + float_conv.stride, # type: ignore[arg-type] + float_conv.padding, # type: ignore[arg-type] + float_conv.output_padding, # type: ignore[arg-type] + float_conv.groups, + float_conv.bias is not None, # type: ignore[arg-type] + float_conv.dilation, # type: ignore[arg-type] + float_conv.padding_mode, + device=float_conv.weight.device, + dtype=float_conv.weight.dtype, + weight_qparams=weight_qparams, + ) + qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach()) + if float_conv.bias is not None: + qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach()) + return qref_conv + + +class ConvTranspose1d(_ConvTransposeNd, nn.ConvTranspose1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: str = "zeros", + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ): + nn.ConvTranspose1d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward( + self, x: torch.Tensor, output_size: Optional[list[int]] = None + ) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.convTranspose1d --- + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.convTranspose1d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv1d + """ + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + output_padding = self._output_padding( + input, # type: ignore[arg-type] + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + self.dilation, # type: ignore[arg-type] + ) + + weight_quant_dequant = self.get_weight() + result = F.conv_transpose1d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + return result + + def _get_name(self): + return "QuantizedConvTranspose1d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams) + + +class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ): + nn.ConvTranspose2d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward( + self, x: torch.Tensor, output_size: Optional[list[int]] = None + ) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.convTranspose2d --- + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.convTranspose2d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv2d + """ + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + + output_padding = self._output_padding( + input, # type: ignore[arg-type] + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + self.dilation, # type: ignore[arg-type] + ) + + weight_quant_dequant = self.get_weight() + result = F.conv_transpose2d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + + return result + + def _get_name(self): + return "QuantizedConvTranspose2d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams) + + +class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + padding_mode="zeros", + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ): + nn.ConvTranspose3d.__init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + groups, + bias, + dilation, + padding_mode, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def forward( + self, x: torch.Tensor, output_size: Optional[list[int]] = None + ) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.convTranspose3d --- + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.convTranspose3d --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized conv3d + """ + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + output_padding = self._output_padding( + input, # type: ignore[arg-type] + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + self.dilation, # type: ignore[arg-type] + ) + + weight_quant_dequant = self.get_weight() + result = F.conv_transpose3d( + x, + weight_quant_dequant, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) + return result + + def _get_name(self): + return "QuantizedConvTranspose3d(Reference)" + + @classmethod + def from_float(cls, float_conv, weight_qparams): # type: ignore[override] + return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/linear.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..67f4aee33ba340130bf2b01dfe2ed2c06b96b23e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/linear.py @@ -0,0 +1,69 @@ +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import ReferenceQuantizedModule + + +__all__ = ["Linear"] + + +class Linear(nn.Linear, ReferenceQuantizedModule): + """A reference quantized linear module that fits into the FX + Graph Mode Quantization workflow + activation will be floating point Tensor, we will store floating + point weight as well in the module, but in forward we'll quantize + and dequantize the weight before running the floating point functional + linear operator. + """ + + _IS_REFERENCE = True + + def __init__( + self, + in_features: int, + out_features: int, + bias_: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + weight_qparams: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__(in_features, out_features, bias_, device, dtype) + self._init_weight_qparams(weight_qparams, device) + + def _get_name(self) -> str: + return "QuantizedLinear(Reference)" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.linear --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.linear --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized linear + """ + weight_quant_dequant = self.get_weight() + result = F.linear(x, weight_quant_dequant, self.bias) + return result + + @classmethod + def from_float( + cls, float_linear: nn.Linear, weight_qparams: dict[str, Any] + ) -> "Linear": + qref_linear = Linear( + float_linear.in_features, + float_linear.out_features, + float_linear.bias is not None, + device=float_linear.weight.device, + dtype=float_linear.weight.dtype, + weight_qparams=weight_qparams, + ) + qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach()) + if float_linear.bias is not None: + qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach()) + return qref_linear diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..adb1356cb3d360a240331d2e0150f8080bfa7314 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py @@ -0,0 +1,853 @@ +# mypy: allow-untyped-defs +from typing import Any, Optional + +import torch +import torch.nn as nn +from torch import _VF, Tensor +from torch.nn.utils.rnn import PackedSequence + +from .utils import _quantize_and_dequantize_weight, _quantize_weight + + +__all__ = [ + "RNNCellBase", + "RNNCell", + "LSTMCell", + "GRUCell", + "RNNBase", + "LSTM", + "GRU", + "get_quantized_weight", +] + + +def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return tensor.index_select(dim, permutation) + + +def _get_weight_and_quantization_params(module, wn): + weight = getattr(module, wn) + params = [weight] + for param_name in [ + wn + n for n in ["_qscheme", "_dtype", "_scale", "_zero_point", "_axis_int"] + ]: + if hasattr(module, param_name): + param = getattr(module, param_name) + else: + param = None + params.append(param) + return params + + +def get_quantized_weight(module, wn): + if not hasattr(module, wn): + return None + params = _get_weight_and_quantization_params(module, wn) + weight = _quantize_weight(*params) + return weight + + +def _get_quantize_and_dequantized_weight(module, wn): + if not hasattr(module, wn): + return None + params = _get_weight_and_quantization_params(module, wn) + weight = _quantize_and_dequantize_weight(*params) + return weight + + +class RNNCellBase(nn.RNNCellBase): + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool, + num_chunks: int, + device=None, + dtype=None, + weight_qparams_dict=None, + ) -> None: + super().__init__( + input_size, hidden_size, bias, num_chunks, device=device, dtype=dtype + ) + # TODO(jerryzh168): maybe make this arg a required arg + if weight_qparams_dict is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0, + } + weight_qparams_dict = { + "weight_ih": weight_qparams, + "weight_hh": weight_qparams, + "is_decomposed": False, + } + assert len(weight_qparams_dict) == 3, ( + "Expected length for weight_qparams_dict to be 3 for QuantizedRNNCellBase(Reference)" + ) + self._init_weight_qparams_dict(weight_qparams_dict, device) + + def _init_weight_qparams_dict(self, weight_qparams_dict, device): + assert weight_qparams_dict is not None + self.is_decomposed = weight_qparams_dict["is_decomposed"] + for key, weight_qparams in weight_qparams_dict.items(): + if key == "is_decomposed": + continue + # TODO: refactor the duplicated code to utils.py + weight_qscheme = weight_qparams["qscheme"] + weight_dtype = weight_qparams["dtype"] + setattr(self, key + "_qscheme", weight_qscheme) + setattr(self, key + "_dtype", weight_dtype) + assert weight_qscheme in [ + None, + torch.per_tensor_affine, + torch.per_channel_affine, + ], Exception( + f"qscheme: {weight_qscheme} is not support in {self._get_name()}" + ) + if weight_qscheme is not None: + scale = weight_qparams["scale"] + scale_tensor = ( + scale.detach().clone() + if isinstance(scale, torch.Tensor) + else torch.tensor(scale, dtype=torch.float, device=device) + ) + self.register_buffer(key + "_scale", scale_tensor) + zp = weight_qparams["zero_point"] + zp_tensor = ( + zp.detach().clone() + if isinstance(zp, torch.Tensor) + else torch.tensor(zp, dtype=torch.int, device=device) + ) + self.register_buffer(key + "_zero_point", zp_tensor) + if weight_qscheme == torch.per_channel_affine: + axis = weight_qparams["axis"] + axis_tensor = ( + axis.detach().clone() + if isinstance(axis, torch.Tensor) + else torch.tensor(axis, dtype=torch.int, device=device) + ) + self.register_buffer(key + "_axis", axis_tensor) + else: + # added for TorchScriptability, not used + self.register_buffer( + key + "_axis", torch.tensor(0, dtype=torch.int, device=device) + ) + setattr(self, key + "_axis_int", getattr(self, key + "_axis").item()) + + def _get_name(self): + return "QuantizedRNNCellBase(Reference)" + + def get_quantized_weight_ih(self): + return get_quantized_weight(self, "weight_ih") + + def get_quantized_weight_hh(self): + return get_quantized_weight(self, "weight_hh") + + def get_weight_ih(self): + return _get_quantize_and_dequantized_weight(self, "weight_ih") + + def get_weight_hh(self): + return _get_quantize_and_dequantized_weight(self, "weight_hh") + + +class RNNCell(RNNCellBase): + """ + We'll store weight_qparams for all the weights (weight_ih and weight_hh), + we need to pass in a `weight_qparams_dict` that maps from weight name, + e.g. weight_ih, to the weight_qparams for that weight + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + nonlinearity: str = "tanh", + device=None, + dtype=None, + weight_qparams_dict: Optional[dict[str, Any]] = None, + ) -> None: + factory_kwargs = { + "device": device, + "dtype": dtype, + "weight_qparams_dict": weight_qparams_dict, + } + super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) + self.nonlinearity = nonlinearity + + def _get_name(self): + return "QuantizedRNNCell(Reference)" + + # TODO: refactor nn.RNNCell to have a _forward that takes weight_ih and weight_hh as input + # and remove duplicated code, same for the other two Cell modules + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + assert input.dim() in ( + 1, + 2, + ), ( + f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + if self.nonlinearity == "tanh": + ret = _VF.rnn_tanh_cell( + input, + hx, + self.get_weight_ih(), + self.get_weight_hh(), + self.bias_ih, + self.bias_hh, + ) + elif self.nonlinearity == "relu": + ret = _VF.rnn_relu_cell( + input, + hx, + self.get_weight_ih(), + self.get_weight_hh(), + self.bias_ih, + self.bias_hh, + ) + else: + ret = input # TODO: remove when jit supports exception flow + raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}") + + if not is_batched: + ret = ret.squeeze(0) + + return ret + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.bias, + mod.nonlinearity, + mod.weight_ih.device, + mod.weight_ih.dtype, + weight_qparams_dict, + ) + ref_mod.weight_ih = mod.weight_ih + ref_mod.weight_hh = mod.weight_hh + ref_mod.bias_ih = mod.bias_ih + ref_mod.bias_hh = mod.bias_hh + return ref_mod + + +class LSTMCell(RNNCellBase): + """ + We'll store weight_qparams for all the weights (weight_ih and weight_hh), + we need to pass in a `weight_qparams_dict` that maps from weight name, + e.g. weight_ih, to the weight_qparams for that weight + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + weight_qparams_dict: Optional[dict[str, Any]] = None, + ) -> None: + factory_kwargs = { + "device": device, + "dtype": dtype, + "weight_qparams_dict": weight_qparams_dict, + } + super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) + + def _get_name(self): + return "QuantizedLSTMCell(Reference)" + + def forward( + self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None + ) -> tuple[Tensor, Tensor]: + assert input.dim() in ( + 1, + 2, + ), ( + f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + zeros = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + hx = (zeros, zeros) + else: + hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx + + ret = _VF.lstm_cell( + input, + hx, + self.get_weight_ih(), + self.get_weight_hh(), + self.bias_ih, + self.bias_hh, + ) + + if not is_batched: + ret = (ret[0].squeeze(0), ret[1].squeeze(0)) + return ret + + @classmethod + def from_float(cls, mod, weight_qparams_dict, use_precomputed_fake_quant=False): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.bias, + mod.weight_ih.device, + mod.weight_ih.dtype, + weight_qparams_dict, + ) + ref_mod.weight_ih = mod.weight_ih + ref_mod.weight_hh = mod.weight_hh + ref_mod.bias_ih = mod.bias_ih + ref_mod.bias_hh = mod.bias_hh + return ref_mod + + +class GRUCell(RNNCellBase): + """ + We'll store weight_qparams for all the weights (weight_ih and weight_hh), + we need to pass in a `weight_qparams_dict` that maps from weight name, + e.g. weight_ih, to the weight_qparams for that weight + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + weight_qparams_dict: Optional[dict[str, Any]] = None, + ) -> None: + factory_kwargs = { + "device": device, + "dtype": dtype, + "weight_qparams_dict": weight_qparams_dict, + } + super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) + + def _get_name(self): + return "QuantizedGRUCell(Reference)" + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + assert input.dim() in ( + 1, + 2, + ), ( + f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + ) + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + ret = _VF.gru_cell( + input, + hx, + self.get_weight_ih(), + self.get_weight_hh(), + self.bias_ih, + self.bias_hh, + ) + + if not is_batched: + ret = ret.squeeze(0) + + return ret + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.bias, + mod.weight_ih.device, + mod.weight_ih.dtype, + weight_qparams_dict, + ) + ref_mod.weight_ih = mod.weight_ih + ref_mod.weight_hh = mod.weight_hh + ref_mod.bias_ih = mod.bias_ih + ref_mod.bias_hh = mod.bias_hh + return ref_mod + + +class RNNBase(nn.RNNBase): + def __init__( + self, + mode: str, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + proj_size: int = 0, + device=None, + dtype=None, + weight_qparams_dict: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__( + mode, + input_size, + hidden_size, + num_layers, + bias, + batch_first, + dropout, + bidirectional, + proj_size, + device, + dtype, + ) + # TODO(jerryzh168): maybe make this arg a required arg + if weight_qparams_dict is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0, + } + weight_qparams_dict = {"is_decomposed": False} # type: ignore[dict-item] + for wn in self._flat_weights_names: + if wn.startswith("weight"): + weight_qparams_dict[wn] = weight_qparams + self._init_weight_qparams_dict(weight_qparams_dict, device) + + def _init_weight_qparams_dict(self, weight_qparams_dict, device): + self.is_decomposed = weight_qparams_dict["is_decomposed"] + for key, weight_qparams in weight_qparams_dict.items(): + if key == "is_decomposed": + continue + weight_qscheme = weight_qparams["qscheme"] + weight_dtype = weight_qparams["dtype"] + setattr(self, key + "_qscheme", weight_qscheme) + setattr(self, key + "_dtype", weight_dtype) + assert weight_qscheme in [ + None, + torch.per_tensor_affine, + torch.per_channel_affine, + ], Exception( + f"qscheme: {weight_qscheme} is not support in {self._get_name()}" + ) + if weight_qscheme is not None: + self.register_buffer( + key + "_scale", + torch.tensor( + weight_qparams["scale"], dtype=torch.float, device=device + ), + ) + self.register_buffer( + key + "_zero_point", + torch.tensor( + weight_qparams["zero_point"], dtype=torch.int, device=device + ), + ) + if weight_qscheme == torch.per_channel_affine: + self.register_buffer( + key + "_axis", + torch.tensor( + weight_qparams["axis"], dtype=torch.int, device=device + ), + ) + else: + # added for TorchScriptability, not used + self.register_buffer( + key + "_axis", torch.tensor(0, dtype=torch.int, device=device) + ) + setattr(self, key + "_axis_int", getattr(self, key + "_axis").item()) + + +class LSTM(RNNBase): + """Reference Quantized LSTM Module + We'll store weight_qparams for all the weights in _flat_weights, we need to pass in + a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0, + to the weight_qparams for that weight + """ + + def __init__(self, *args, **kwargs): + super().__init__("LSTM", *args, **kwargs) + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + def permute_hidden( # type: ignore[override] + self, + hx: tuple[Tensor, Tensor], + permutation: Optional[Tensor], + ) -> tuple[Tensor, Tensor]: + if permutation is None: + return hx + return _apply_permutation(hx[0], permutation), _apply_permutation( + hx[1], permutation + ) + + def get_expected_cell_size( + self, input: Tensor, batch_sizes: Optional[Tensor] + ) -> tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) + return expected_hidden_size + + # In the future, we should prevent mypy from applying contravariance rules here. + # See torch/nn/modules/module.py::_forward_unimplemented + def check_forward_args( # type: ignore[override] + self, + input: Tensor, + hidden: tuple[Tensor, Tensor], + batch_sizes: Optional[Tensor], + ): + self.check_input(input, batch_sizes) + self.check_hidden_size( + hidden[0], + self.get_expected_hidden_size(input, batch_sizes), + "Expected hidden[0] size {}, got {}", + ) + self.check_hidden_size( + hidden[1], + self.get_expected_cell_size(input, batch_sizes), + "Expected hidden[1] size {}, got {}", + ) + + def get_quantized_weight_bias_dict(self): + """dictionary from flat_weight_name to quantized weight or (unquantized) bias + e.g. + { + "weight_ih_l0": quantized_weight, + "bias_ih_l0": unquantized_bias, + ... + } + """ + quantized_weight_bias_dict = {} + for wn in self._flat_weights_names: + if hasattr(self, wn): + if wn.startswith("weight"): + weight_or_bias = get_quantized_weight(self, wn) + else: + weight_or_bias = getattr(self, wn) + else: + weight_or_bias = None + quantized_weight_bias_dict[wn] = weight_or_bias + return quantized_weight_bias_dict + + def get_flat_weights(self): + flat_weights = [] + for wn in self._flat_weights_names: + if hasattr(self, wn): + weight = getattr(self, wn) + if wn.startswith("weight"): + params = _get_weight_and_quantization_params(self, wn) + weight = _quantize_and_dequantize_weight(*params) + else: + weight = None + flat_weights.append(weight) + return flat_weights + + def forward(self, input, hx=None): # noqa: F811 + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + batch_sizes = None + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + else: + batch_sizes = None + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + if hx is None: + num_directions = 2 if self.bidirectional else 1 + real_hidden_size = ( + self.proj_size if self.proj_size > 0 else self.hidden_size + ) + h_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + real_hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + hx = (h_zeros, c_zeros) + else: + if batch_sizes is None: # If not PackedSequence input. + if is_batched: # type: ignore[possibly-undefined] + if hx[0].dim() != 3 or hx[1].dim() != 3: + msg = ( + "For batched 3-D input, hx and cx should " + f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) + raise RuntimeError(msg) + else: + if hx[0].dim() != 2 or hx[1].dim() != 2: + msg = ( + "For unbatched 2-D input, hx and cx should " + f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) + raise RuntimeError(msg) + hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1)) + + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + if batch_sizes is None: + result = _VF.lstm( + input, + hx, + self.get_flat_weights(), + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.lstm( + input, + batch_sizes, + hx, + self.get_flat_weights(), + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1:] + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, batch_sizes, sorted_indices, unsorted_indices + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1)) + return output, self.permute_hidden(hidden, unsorted_indices) + + def _get_name(self): + return "QuantizedLSTM(Reference)" + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + weight_qparams_dict=weight_qparams_dict, + ) + for wn in mod._flat_weights_names: + setattr(ref_mod, wn, getattr(mod, wn)) + return ref_mod + + +class GRU(RNNBase): + """Reference Quantized GRU Module + We'll store weight_qparams for all the weights in _flat_weights, we need to pass in + a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0, + to the weight_qparams for that weight + """ + + def __init__(self, *args, **kwargs): + if "proj_size" in kwargs: + raise ValueError( + "proj_size argument is only supported for LSTM, not RNN or GRU" + ) + super().__init__("GRU", *args, **kwargs) + + def get_quantized_weight_bias_dict(self): + """dictionary from flat_weight_name to quantized weight or (unquantized) bias + e.g. + { + "weight_ih_l0": quantized_weight, + "bias_ih_l0": unquantized_bias, + ... + } + """ + quantized_weight_bias_dict = {} + for wn in self._flat_weights_names: + if hasattr(self, wn): + if wn.startswith("weight"): + weight_or_bias = get_quantized_weight(self, wn) + else: + weight_or_bias = getattr(self, wn) + else: + weight_or_bias = None + quantized_weight_bias_dict[wn] = weight_or_bias + return quantized_weight_bias_dict + + def get_flat_weights(self): + flat_weights = [] + for wn in self._flat_weights_names: + if hasattr(self, wn): + weight = getattr(self, wn) + if wn.startswith("weight"): + params = _get_weight_and_quantization_params(self, wn) + weight = _quantize_and_dequantize_weight(*params) + else: + weight = None + flat_weights.append(weight) + return flat_weights + + def forward(self, input, hx=None): # noqa: F811 + # Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py + # only changed self._flat_weights to self.get_flat_weights() + # TODO: maybe we can try inheriting from that class and define get_flat_weights + # as a @property? this might interfere with TorchScript, if we remove that + # requirement in the future we should be able to do this + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + else: + batch_sizes = None + assert input.dim() in ( + 2, + 3, + ), ( + f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor" + ) + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + if hx is not None: + if hx.dim() != 2: + raise RuntimeError( + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" + ) + hx = hx.unsqueeze(1) + else: + if hx is not None and hx.dim() != 3: + raise RuntimeError( + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" + ) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + if batch_sizes is None: + result = _VF.gru( + input, + hx, + self.get_flat_weights(), + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) + else: + result = _VF.gru( + input, + batch_sizes, + hx, + self.get_flat_weights(), + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) + output = result[0] + hidden = result[1] + + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence( + output, batch_sizes, sorted_indices, unsorted_indices + ) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = hidden.squeeze(1) + + return output, self.permute_hidden(hidden, unsorted_indices) + + def _get_name(self): + return "QuantizedGRU(Reference)" + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + weight_qparams_dict=weight_qparams_dict, + ) + for wn in mod._flat_weights_names: + setattr(ref_mod, wn, getattr(mod, wn)) + return ref_mod diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/sparse.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..7e4bdb9b02c71c8dd5c90db43bd4742f7dbf152c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/sparse.py @@ -0,0 +1,162 @@ +# mypy: allow-untyped-defs +from typing import Any, Optional + +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from .utils import ReferenceQuantizedModule + + +__all__ = ["Embedding", "EmbeddingBag"] + + +class Embedding(nn.Embedding, ReferenceQuantizedModule): + """A reference quantized Embedding module that fits into the + FX Graph Mode Quantization workflow, activation will be floating point Tensor, + we will store floating point weight as well in the module, but in forward we'll + quantize and dequantize the weight before running the floating point functional + embedding operator. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def _get_name(self): + return "QuantizedEmbedding(Reference)" + + def forward(self, input: Tensor) -> Tensor: + weight_quant_dequant = self.get_weight() + return F.embedding( + input, + weight_quant_dequant, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + @classmethod + def from_float(cls, mod, weight_qparams): + return cls( + mod.num_embeddings, + mod.embedding_dim, + mod.padding_idx, + mod.max_norm, + mod.norm_type, + mod.scale_grad_by_freq, + mod.sparse, + mod.weight, + mod.weight.device, + mod.weight.dtype, + weight_qparams, + ) + + +class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule): + """A reference quantized EmbeddingBag module that fits into the + FX Graph Mode Quantization workflow, activation will be floating point Tensor, + we will store floating point weight as well in the module, but in forward we'll + quantize and dequantize the weight before running the floating point functional + embedding operator. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + _weight: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None, + device=None, + dtype=None, + weight_qparams: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__( + num_embeddings, + embedding_dim, + max_norm, + norm_type, + scale_grad_by_freq, + mode, + sparse, + _weight, + include_last_offset, + padding_idx, + device, + dtype, + ) + self._init_weight_qparams(weight_qparams, device) + + def _get_name(self): + return "QuantizedEmbedding(Reference)" + + def forward( + self, + input: Tensor, + offsets: Optional[Tensor] = None, + per_sample_weights: Optional[Tensor] = None, + ) -> Tensor: + weight_quant_dequant = self.get_weight() + return F.embedding_bag( + input, + weight_quant_dequant, + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) + + @classmethod + def from_float(cls, mod, weight_qparams, use_precomputed_fake_quant=False): + return cls( + mod.num_embeddings, + mod.embedding_dim, + mod.max_norm, + mod.norm_type, + mod.scale_grad_by_freq, + mod.mode, + mod.sparse, + mod.weight, + mod.include_last_offset, + mod.padding_idx, + mod.weight.device, + mod.weight.dtype, + weight_qparams, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/utils.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0701b73da38b0e252380b0c58265e16960e66e01 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/quantized/reference/modules/utils.py @@ -0,0 +1,434 @@ +# mypy: allow-untyped-defs +import typing + +import torch + + +__all__ = [ + "ReferenceQuantizedModule", +] + + +class ReferenceQuantizedModule(torch.nn.Module): + def _init_weight_qparams(self, weight_qparams, device): + if weight_qparams is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0, + } + self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"] + self.weight_dtype = weight_qparams["dtype"] + assert self.weight_qscheme in [ + None, + torch.per_tensor_affine, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ], ( + f"qscheme: {self.weight_qscheme} is not support in reference quantized {self._get_name()}" + ) + if self.weight_dtype in [ + torch.quint8, + torch.qint8, + torch.quint4x2, + torch.qint32, + ]: + zero_point_dtype = ( + weight_qparams["zero_point"].dtype + if isinstance(weight_qparams["zero_point"], torch.Tensor) + else torch.int + ) + w_scale = weight_qparams["scale"] + w_scale_tensor = ( + w_scale.detach().clone() + if isinstance(w_scale, torch.Tensor) + else torch.tensor(w_scale, dtype=torch.float, device=device) + ) + self.register_buffer("weight_scale", w_scale_tensor) + w_zp = weight_qparams["zero_point"] + w_zp_tensor = ( + w_zp.detach().clone() + if isinstance(w_zp, torch.Tensor) + else torch.tensor(w_zp, dtype=zero_point_dtype, device=device) + ) + self.register_buffer("weight_zero_point", w_zp_tensor) + if self.weight_qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + w_axis = weight_qparams["axis"] + w_axis_tensor = ( + w_axis.detach().clone() + if isinstance(w_axis, torch.Tensor) + else torch.tensor(w_axis, dtype=torch.int, device=device) + ) + self.register_buffer("weight_axis", w_axis_tensor) + else: + # added for TorchScriptability, not used + self.register_buffer( + "weight_axis", torch.tensor(0, dtype=torch.int, device=device) + ) + else: + # added for TorchScriptability, and for torch.float + self.register_buffer( + "weight_scale", torch.tensor(1.0, dtype=torch.float, device=device) + ) + self.register_buffer( + "weight_zero_point", torch.tensor(0, dtype=torch.int, device=device) + ) + self.register_buffer( + "weight_axis", torch.tensor(0, dtype=torch.int, device=device) + ) + self.is_decomposed: bool = weight_qparams.get("is_decomposed", False) + # store weight_axis as weight_axis_int due to some constraints of torchdynamo.export + # for capturing `.item` operations + self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment] + self.weight_quant_min: typing.Optional[int] = weight_qparams.get( + "quant_min", None + ) + self.weight_quant_max: typing.Optional[int] = weight_qparams.get( + "quant_max", None + ) + + def get_weight(self): + """ + Fake quantize (quantize and dequantize) the weight with + the quantization parameters for weight, this is used to + simulate the numerics for the quantized weight in a quantized + model + """ + # suppress mypy warning + assert isinstance(self.weight_scale, torch.Tensor) + assert isinstance(self.weight_zero_point, torch.Tensor) + if self.is_decomposed: + return _quantize_and_dequantize_weight_decomposed( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + self.weight_quant_min, + self.weight_quant_max, + ) + else: + return _quantize_and_dequantize_weight( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + ) + + def get_quantized_weight(self): + # suppress mypy warning + assert isinstance(self.weight_scale, torch.Tensor) + assert isinstance(self.weight_zero_point, torch.Tensor) + # assert isinstance(self.weight_axis, torch.Tensor) + if self.is_decomposed: + return _quantize_weight_decomposed( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + self.weight_quant_min, + self.weight_quant_max, + ) + else: + return _quantize_weight( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + _save_weight_qparams( + destination, + prefix, + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis, + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + for key in _get_weight_qparam_keys(state_dict, prefix): + setattr(self, key, state_dict[prefix + key]) + state_dict.pop(prefix + key) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +def _quantize_weight_decomposed( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis: int, + weight_quant_min: typing.Optional[int], + weight_quant_max: typing.Optional[int], +) -> torch.Tensor: + _DTYPE_TO_QVALUE_BOUNDS: dict[torch.dtype, tuple[int, int]] = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int32: (int(-(2**31)), int(2**31 - 1)), + } + + # TODO: add an util function for converting qdtype to dtype + _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + } + if weight_qscheme == torch.per_tensor_affine: + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype] + if weight_quant_min is None or weight_quant_max is None: + weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[ + weight_dtype_ + ] + weight = torch.ops.quantized_decomposed.quantize_per_tensor( + weight, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + weight_dtype_, + ) + return weight + elif weight_qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + # TODO: torch.quint4x2 is not supported + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype] + if weight_quant_min is None or weight_quant_max is None: + weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[ + weight_dtype_ + ] + weight = torch.ops.quantized_decomposed.quantize_per_channel( + weight, + weight_scale, + weight_zero_point, + weight_axis, + weight_quant_min, + weight_quant_max, + weight_dtype_, + ) # type: ignore[arg-type] + return weight + raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}") + + +def _dequantize_weight_decomposed( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis: int, + weight_quant_min: typing.Optional[int], + weight_quant_max: typing.Optional[int], +) -> torch.Tensor: + # TODO: get the quant_min and quant_max from activation_post_process + _DTYPE_TO_QVALUE_BOUNDS: dict[torch.dtype, tuple[int, int]] = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int32: (int(-(2**31)), int(2**31 - 1)), + } + # TODO: add an util function for converting qdtype to dtype + _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + } + weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype] + if weight_quant_min is None or weight_quant_max is None: + weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_] + if weight_qscheme == torch.per_tensor_affine: + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + weight_dtype_, + ) + return weight + elif weight_qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + # TODO: torch.quint4x2 is not supported + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight = torch.ops.quantized_decomposed.dequantize_per_channel( + weight, + weight_scale, + weight_zero_point, + weight_axis, + weight_quant_min, + weight_quant_max, + weight_dtype_, + ) # type: ignore[arg-type] + return weight + raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}") + + +def _quantize_weight( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis_int: int, +) -> torch.Tensor: + if weight_dtype == torch.float16: + weight = weight.to(weight_dtype) + return weight + + if weight_qscheme == torch.per_tensor_affine: + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight = torch.quantize_per_tensor( + weight, weight_scale, weight_zero_point, weight_dtype + ) + return weight + elif weight_qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]: + weight = torch.quantize_per_channel( + weight, weight_scale, weight_zero_point, weight_axis_int, weight_dtype + ) # type: ignore[arg-type] + return weight + raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}") + + +def _quantize_and_dequantize_weight_decomposed( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis_int: int, + weight_quant_min: typing.Optional[int], + weight_quant_max: typing.Optional[int], +) -> torch.Tensor: + """Quantize and then dequantize the weight based on + the quantization parameters + """ + if weight_qscheme in [ + torch.per_tensor_affine, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + weight_quant = _quantize_weight_decomposed( + weight, + weight_qscheme, + weight_dtype, + weight_scale, + weight_zero_point, + weight_axis_int, + weight_quant_min, + weight_quant_max, + ) + weight_dequant = _dequantize_weight_decomposed( + weight_quant, + weight_qscheme, + weight_dtype, + weight_scale, + weight_zero_point, + weight_axis_int, + weight_quant_min, + weight_quant_max, + ) + else: + weight_dequant = weight + return weight_dequant + + +def _quantize_and_dequantize_weight( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis_int: int, +) -> torch.Tensor: + """Quantize and then dequantize the weight based on + the quantization parameters + """ + if weight_qscheme in [ + torch.per_tensor_affine, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ]: + weight_quant = _quantize_weight( + weight, + weight_qscheme, + weight_dtype, + weight_scale, + weight_zero_point, + weight_axis_int, + ) + weight_dequant = weight_quant.dequantize() + else: + weight_dequant = weight + return weight_dequant + + +def _save_weight_qparams( + destination, + prefix, + weight_qscheme, + weight_dtype, + weight_scale, + weight_zero_point, + weight_axis, +): + destination[prefix + "weight_qscheme"] = weight_qscheme + destination[prefix + "weight_dtype"] = weight_dtype + if weight_qscheme is not None: + destination[prefix + "weight_scale"] = weight_scale + destination[prefix + "weight_zero_point"] = weight_zero_point + if weight_qscheme == torch.per_channel_affine: + destination[prefix + "weight_axis"] = weight_axis + + +def _get_weight_qparam_keys(state_dict: dict[str, typing.Any], prefix: str): + keys = ["weight_qscheme", "weight_dtype"] + weight_qscheme = state_dict[prefix + "weight_qscheme"] + if weight_qscheme is not None: + keys.append("weight_scale") + keys.append("weight_zero_point") + if weight_qscheme == torch.quantize_per_channel: + keys.append("weight_axis") + return keys diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0fda5a58f2984ee05b0d167297b458f62c37fc59 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/__init__.py @@ -0,0 +1 @@ +from . import quantized diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..932ec6f4f42bf3a303374125dede26971aaa8837 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef66c90b0e8ecdbc7cd2cfb4c1cecf0bc38e8466 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__init__.py @@ -0,0 +1,10 @@ +from torch.ao.nn.sparse.quantized import dynamic + +from .linear import Linear, LinearPackedParams + + +__all__ = [ + "dynamic", + "Linear", + "LinearPackedParams", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e46b03e013d468570622a2dcbca571ee6e1e6e2d Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f07b09aba548d8c937647723e7d8dd4a365a9a0e Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f12609b9436328242a78054f1628265d6140343 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91ecfd8793dc08b96ed64f47f531724aa8a866d0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__init__.py @@ -0,0 +1,6 @@ +from .linear import Linear + + +__all__ = [ + "Linear", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..379e0b450cd4a9424ae9eceeee88fe562c210ab0 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22d2bfac91cc4e2a15fe6e9500e8de4fde368fdb Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..6da18e151012128fbe935f790513e603fc372a7a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py @@ -0,0 +1,189 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.ao.nn.intrinsic as nni +from torch.ao.nn.quantized.modules.utils import ( + _hide_packed_params_repr, + _quantize_weight, +) +from torch.ao.nn.sparse.quantized import linear +from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern + + +__all__ = ["Linear"] + + +class Linear(torch.nn.Module): + r""" + A dynamically quantized sparse linear module with float tensor as inputs and outputs. + """ + + _version = 1 + _op_type = "sparse_dynamic" + _FLOAT_MODULE = torch.nn.Linear + + def __init__( + self, + in_features, + out_features, + row_block_size, + col_block_size, + bias=True, + dtype=torch.qint8, + ): + super().__init__() + + if dtype != torch.qint8: + raise NotImplementedError( + "Only QINT8 is supported for Sparse Quantized Linear Dynamic" + ) + + self.in_features = in_features + self.out_features = out_features + + if bias: + bias = torch.zeros(self.out_features, dtype=torch.float) + else: + bias = None + + qweight = torch._empty_affine_quantized( + [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8 + ) + self._packed_params = linear.LinearPackedParams( + row_block_size=row_block_size, col_block_size=col_block_size, dtype=dtype + ) + self._packed_params.set_weight_bias( + qweight, bias, row_block_size, col_block_size + ) + + def _get_name(self): + return "SparseQuantizedDynamicLinear" + + def extra_repr(self): + return f"in_features={self.in_features}, out_features={self.out_features}, qscheme={self.weight().qscheme()}" + + def __repr__(self): + return _hide_packed_params_repr(self, linear.LinearPackedParams) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.sparse.qlinear_dynamic(x, self._packed_params._packed_params) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "op_type"] = self._op_type + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + op_type = int(state_dict[prefix + "op_type"]) + assert op_type == "sparse", ( + f"Cannot load from op_type [{op_type}], expecting [{self._op_type}]" + ) + state_dict.pop(prefix + "op_type") + + version = local_metadata.get("version", None) + assert version <= self._version + + # Is this code valid? In old quantization it seemed to be used to load + # older model + weight = state_dict.pop(prefix + "weight") + bias = state_dict.pop(prefix + "bias") + state_dict.update( + { + prefix + "_packed_params.weight": weight, + prefix + "_packed_params.bias": bias, + } + ) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _weight_bias(self): + return self._packed_params._weight_bias() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def set_weight_bias( + self, + w: torch.Tensor, + b: Optional[torch.Tensor], + row_block_size: Optional[int], + col_block_size: Optional[int], + ) -> None: + assert row_block_size is not None and col_block_size is not None + self.out_features = w.shape[0] + self.in_features = w.shape[1] + self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized sparse dynamic module from a float module. + + We only care about the convert at this stage, no need for observers just yet. + """ + assert type(mod) == cls._FLOAT_MODULE, ( + " nnq." + + cls.__name__ + + ".from_float only works for " + + cls._FLOAT_MODULE.__name__ + ) + # TODO: Need to add options to qconfig to avoid the calibration. + # TODO: Add calibration for the sparsity + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + if type(mod) == nni.LinearReLU: + mod = mod[0] + if mod.qconfig is not None and mod.qconfig.weight is not None: + weight_observer = mod.qconfig.weight() + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. + from torch.ao.quantization.qconfig import default_dynamic_qconfig + + weight_observer = default_dynamic_qconfig.weight() + + # It is important to multiply by the mask BEFORE calling the `weight_observer` + # TODO (zaf): Mask might not be part of the qconfig (T83295194) + weight = mod.weight + if getattr(mod.qconfig, "mask", False): + weight = mod.qconfig.mask * mod.weight + + weight_observer(weight) + dtype = weight_observer.dtype + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + _w_sc, w_zp = weight_observer.calculate_qparams() + if isinstance(w_zp, torch.Tensor): + assert not torch.any(w_zp.bool()), "All weight zero points must map to 0" + else: + assert w_zp == 0, "Weight zero point must map to 0" + qweight = _quantize_weight(weight.float(), weight_observer) + + row_block_size, col_block_size = LinearBlockSparsePattern.block_size() + qlinear = cls( + mod.in_features, + mod.out_features, + row_block_size, + col_block_size, + dtype=dtype, + ) + qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size) + return qlinear diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/linear.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e3dbf23b9f682ce6e930b4a7f63677ceafe52e71 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/linear.py @@ -0,0 +1,275 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch.ao.nn.quantized.modules.utils import ( + _hide_packed_params_repr, + _quantize_weight, +) + + +__all__ = ["LinearPackedParams", "Linear"] + + +# TODO (zaf): Inherit from `quantized.LinearPackedParams` (T83294430) +class LinearPackedParams(torch.nn.Module): + _version = 1 + + def __init__(self, row_block_size=1, col_block_size=4, dtype=torch.qint8): + super().__init__() + + if dtype != torch.qint8: + raise NotImplementedError("Linear prepacking only supports QINT8") + self.dtype = dtype + wq = torch._empty_affine_quantized( + [1, 1], scale=1.0, zero_point=0, dtype=torch.qint8 + ) + self.set_weight_bias(wq, None, row_block_size, col_block_size) + + def _get_name(self): + return "SparseQuantizedLinearPackedParams" + + @torch.jit.export + def set_weight_bias( + self, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + row_block_size: Optional[int], + col_block_size: Optional[int], + ) -> None: + assert row_block_size is not None and col_block_size is not None + self._packed_params = torch.ops.sparse.qlinear_prepack( + weight, bias, row_block_size, col_block_size + ) + + @torch.jit.export + def _weight_bias(self): + (weight, bias, block_sizes) = torch.ops.sparse.qlinear_unpack( + self._packed_params + ) + return (weight, bias, block_sizes[0], block_sizes[1]) + + def forward(self, x): + return x + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "dtype"] = self.dtype + destination[prefix + "_packed_params"] = self._weight_bias() + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + assert version <= self._version + + self.dtype = state_dict.pop(prefix + "dtype") + weight, bias, row_block_size, col_block_size = state_dict.pop( + prefix + "_packed_params" + ) + self.set_weight_bias(weight, bias, row_block_size, col_block_size) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def __getstate__(self): + return self._packed_params, self.training, self.dtype + + @torch.jit.export + def __setstate__(self, state): + (self._packed_params, self.training, self.dtype) = state + + def __repr__(self): + return self._weight_bias().__repr__() + + +# TODO (zaf): Inherit from `quantized.Linear` (T83294430) +class Linear(torch.nn.Module): + r""" + A quantized sparse linear module with quantized tensor as inputs and outputs. + """ + + _version = 1 + _FLOAT_MODULE = torch.nn.Linear + + def __init__( + self, + in_features, + out_features, + row_block_size, + col_block_size, + bias=True, + dtype=torch.qint8, + ): + super().__init__() + + if dtype != torch.qint8: + raise NotImplementedError( + "Only QINT8 is supported for Sparse Quantized Linear" + ) + + self.in_features = in_features + self.out_features = out_features + + if bias: + bias = torch.zeros(self.out_features, dtype=torch.float) + else: + bias = None + + qweight = torch._empty_affine_quantized( + [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8 + ) + self._packed_params = LinearPackedParams( + row_block_size=row_block_size, col_block_size=col_block_size, dtype=dtype + ) + self._packed_params.set_weight_bias( + qweight, bias, row_block_size, col_block_size + ) + self.scale = 1.0 + self.zero_point = 0 + + @classmethod + def _get_name(cls): + return "SparseQuantizedLinear" + + def extra_repr(self): + return ( + f"in_features={self.in_features}, out_features={self.out_features}, scale={self.scale}, " + f"zero_point={self.zero_point}, qscheme={self.weight().qscheme()}" + ) + + def __repr__(self): + return _hide_packed_params_repr(self, LinearPackedParams) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.sparse.qlinear( + x, self._packed_params._packed_params, self.scale, self.zero_point + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "scale"] = torch.tensor(self.scale) + destination[prefix + "zero_point"] = torch.tensor(self.zero_point) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.scale = float(state_dict[prefix + "scale"]) + state_dict.pop(prefix + "scale") + + self.zero_point = int(state_dict[prefix + "zero_point"]) + state_dict.pop(prefix + "zero_point") + + state_dict.pop(prefix + "op_type") + + version = local_metadata.get("version", None) + assert version <= self._version + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _weight_bias(self): + return self._packed_params._weight_bias() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def set_weight_bias( + self, + w: torch.Tensor, + b: Optional[torch.Tensor], + row_block_size: Optional[int], + col_block_size: Optional[int], + ) -> None: + assert row_block_size is not None and col_block_size is not None + self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + r"""Create a quantized sparse module from a float module. + + We only care about the convert at this stage, no need for observers just yet. + + TODO(zaf): Need to add the sparse params to the qconfig + """ + assert type(mod) == cls._FLOAT_MODULE, ( + cls._get_name() + ".from_float only works for " + cls._FLOAT_MODULE.__name__ + ) + assert hasattr(mod, "sparse_params"), ( + "Expecting the Linear to have `sparse_params`. Make sure you have provided arguments " + 'in the `sparsifier.squash_mask(params_to_save=("sparse_block_shape",))` method.' + ) + sparse_block_shape = mod.sparse_params.get("sparse_block_shape", None) # type: ignore[operator, union-attr] + assert isinstance(sparse_block_shape, (tuple, list)) + assert len(sparse_block_shape) == 2 + # TODO: Need to add options to qconfig to avoid the calibration. + # TODO: Add calibration for the sparsity + assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" + activation_post_process = mod.activation_post_process + weight_post_process = mod.qconfig.weight() # type: ignore[operator, union-attr] + + # Assumption is that the weight is already sparsified by the + # `sparsifier.convert` + weight = mod.weight + + weight_post_process(weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[operator, union-attr] + assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8" + w_sc, w_zp = weight_post_process.calculate_qparams() + if isinstance(w_zp, torch.Tensor): + assert not torch.any(w_zp.bool()), "All weight zero points must map to 0" + else: + assert w_zp == 0, "Weight zero point must map to 0" + qweight = _quantize_weight(weight.float(), weight_post_process) + + row_block_size = mod.sparse_params["sparse_block_shape"][0] # type: ignore[index] + col_block_size = mod.sparse_params["sparse_block_shape"][1] # type: ignore[index] + qlinear = cls( + mod.in_features, + mod.out_features, + row_block_size, + col_block_size, + dtype=dtype, + ) + qlinear.set_weight_bias( + qweight, + mod.bias, + row_block_size, # type: ignore[arg-type] + col_block_size, # type: ignore[arg-type] + ) + qlinear.scale = float(act_scale) + qlinear.zero_point = int(act_zp) + return qlinear diff --git a/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/utils.py b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..70daa8fd9f361eeecb68375278785e0c4e8f3c33 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/nn/sparse/quantized/utils.py @@ -0,0 +1,63 @@ +import threading +from typing import Optional + + +__all__ = ["LinearBlockSparsePattern"] + + +def _is_valid_linear_block_sparse_pattern( + row_block_size: int, col_block_size: int +) -> bool: + return (row_block_size == 1 and col_block_size == 4) or ( + row_block_size == 8 and col_block_size == 1 + ) + + +# This is a stop-gap measure as current flow does not allow module +# specific block sparse pattern. +# Infact there is no way to convey sparse pattern via module config +# of quantization flow. Thus using the global context to convey +# sparsity pattern. +# Once the flow supports it, this should be removed. +class LinearBlockSparsePattern: + rlock = threading.RLock() + row_block_size: int = 1 + col_block_size: int = 4 + prev_row_block_size: int = 1 + prev_col_block_size: int = 4 + + def __init__(self, row_block_size: int = 1, col_block_size: int = 4): + assert _is_valid_linear_block_sparse_pattern(row_block_size, col_block_size) + LinearBlockSparsePattern.rlock.acquire() + LinearBlockSparsePattern.prev_row_block_size = ( + LinearBlockSparsePattern.row_block_size + ) + LinearBlockSparsePattern.prev_col_block_size = ( + LinearBlockSparsePattern.col_block_size + ) + LinearBlockSparsePattern.row_block_size = row_block_size + LinearBlockSparsePattern.col_block_size = col_block_size + + def __enter__(self) -> None: + pass + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + backtrace: Optional[object], + ) -> None: + LinearBlockSparsePattern.row_block_size = ( + LinearBlockSparsePattern.prev_row_block_size + ) + LinearBlockSparsePattern.col_block_size = ( + LinearBlockSparsePattern.prev_col_block_size + ) + LinearBlockSparsePattern.rlock.release() + + @staticmethod + def block_size() -> tuple[int, int]: + return ( + LinearBlockSparsePattern.row_block_size, + LinearBlockSparsePattern.col_block_size, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/ns/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/_numeric_suite.py b/.venv/lib/python3.12/site-packages/torch/ao/ns/_numeric_suite.py new file mode 100644 index 0000000000000000000000000000000000000000..96d24a2cf2e75465ac23f22551a98c3292af81dc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/ns/_numeric_suite.py @@ -0,0 +1,567 @@ +# mypy: allow-untyped-defs +from typing import Any, Callable, Optional, Union + +import torch +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.nn as nn +from torch.ao.quantization import prepare +from torch.ao.quantization.quantization_mappings import ( + get_default_compare_output_module_list, +) + + +NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = { + nnqd.Linear, + nnq.Linear, + nnqd.LSTM, + nn.LSTM, +} + + +def _find_match( + str_list: Union[dict[str, Any], list[str]], + key_str: str, + postfix: str, +) -> Optional[str]: + split_str = key_str.split(".") + if split_str[-1] == postfix: + match_string = "".join(key_str.split(".")[0:-1]) + for s2 in str_list: + pattern1 = "".join(s2.split(".")[0:-1]) + pattern2 = "".join(s2.split(".")[0:-2]) + if match_string == pattern1: + return s2 + if match_string == pattern2: + return s2 + + # For matching "fc.weight" and "fc._packed_params._packed_params" + if postfix == "_packed_params": + match_string = "".join(key_str.split(".")[0:-2]) + if len(match_string) == 0: + return None + for s2 in str_list: + pattern1 = "".join(s2.split(".")[0:-1]) + pattern2 = "".join(s2.split(".")[0:-2]) + if match_string == pattern1: + return s2 + if match_string == pattern2: + return s2 + return None + else: + return None + + +def compare_weights( + float_dict: dict[str, Any], quantized_dict: dict[str, Any] +) -> dict[str, dict[str, torch.Tensor]]: + r"""Compare the weights of the float module with its corresponding quantized + module. Return a dict with key corresponding to module names and each entry being + a dictionary with two keys 'float' and 'quantized', containing the float and + quantized weights. This dict can be used to compare and compute the quantization + error of the weights of float and quantized models. + + Example usage:: + + wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict()) + for key in wt_compare_dict: + print( + key, + compute_error( + wt_compare_dict[key]["float"], + wt_compare_dict[key]["quantized"].dequantize(), + ), + ) + + Args: + float_dict: state dict of the float model + quantized_dict: state dict of the quantized model + + Return: + weight_dict: dict with key corresponding to module names and each entry being + a dictionary with two keys 'float' and 'quantized', containing the float and + quantized weights + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights") + weight_dict: dict[str, dict] = {} + for key in quantized_dict: + match_key = _find_match(float_dict, key, "weight") + if match_key is not None: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[match_key] + weight_dict[key]["quantized"] = quantized_dict[key] + continue + + # For matching "fc.weight" and "fc._packed_params._packed_params" + match_key = _find_match(float_dict, key, "_packed_params") + if match_key is not None: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[match_key] + weight_dict[key]["quantized"] = quantized_dict[key][0] + + # For LSTM + split_str = key.split(".") + if split_str[-1] == "param" and split_str[-3] == "_all_weight_values": + layer = split_str[-2] + module_name = ".".join(split_str[:-3]) + float_weight_ih_key = module_name + ".weight_ih_l" + layer + float_weight_hh_key = module_name + ".weight_hh_l" + layer + if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[float_weight_ih_key] + weight_dict[key]["quantized"] = ( + quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0] + ) + weight_dict[key]["float"] = float_dict[float_weight_hh_key] + weight_dict[key]["quantized"] = ( + quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0] + ) + + return weight_dict + + +def _get_logger_dict_helper( + mod: nn.Module, + target_dict: dict[str, Any], + prefix: str = "", +) -> None: + r"""This is the helper function for get_logger_dict + + Args: + mod: module we want to save all logger stats + prefix: prefix for the current module + target_dict: the dictionary used to save all logger stats + """ + + def get_prefix(prefix): + return prefix if prefix == "" else prefix + "." + + for name, child in mod.named_children(): + if isinstance(child, Logger): + target_dict[get_prefix(prefix) + "stats"] = child.stats + break + + for name, child in mod.named_children(): + module_prefix = get_prefix(prefix) + name if prefix else name + _get_logger_dict_helper(child, target_dict, module_prefix) + + +def get_logger_dict(mod: nn.Module, prefix: str = "") -> dict[str, dict]: + r"""Traverse the modules and save all logger stats into target dict. + This is mainly used for quantization accuracy debug. + + Type of loggers supported: + ShadowLogger: used to log the outputs of the quantized module and its matching float shadow module, + OutputLogger: used to log the outputs of the modules + + Args: + mod: module we want to save all logger stats + prefix: prefix for the current module + + Return: + target_dict: the dictionary used to save all logger stats + + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict") + + target_dict: dict[str, dict] = {} + _get_logger_dict_helper(mod, target_dict, prefix) + return target_dict + + +class Logger(nn.Module): + r"""Base class for stats logging""" + + def __init__(self): + super().__init__() + self.stats = {} + # We only insert observer if the op is quantized with static quantization, + # which is identified by activation_observer.dtype == quint8. This is needed + # when attaching Logger as observer for FX mode + self.dtype = torch.quint8 + + def forward(self, x): + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + + +class ShadowLogger(Logger): + r"""Class used in Shadow module to record the outputs of the original and + shadow modules. + """ + + def __init__(self): + super().__init__() + self.stats["float"] = [] + self.stats["quantized"] = [] + + def forward(self, x, y): # type: ignore[override] + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + if len(x) > 1: + x = x[0] + if len(y) > 1: + y = y[0] + self.stats["quantized"].append(x.detach()) + self.stats["float"].append(y.detach()) + + +class OutputLogger(Logger): + r"""Class used to log the outputs of the module""" + + def __init__(self): + super().__init__() + self.stats["tensor_val"] = [] + + def forward(self, x): + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + self.stats["tensor_val"].append(x) + return x + + +def _convert_tuple_to_list(t: Any) -> Any: + return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t + + +def _dequantize_tensor_list(t: Any) -> Any: + return ( + [_dequantize_tensor_list(x) for x in t] + if type(t) is list + else t.dequantize() + if t.is_quantized + else t + ) + + +class Shadow(nn.Module): + r"""Shadow module attaches the float module to its matching quantized module + as the shadow. Then it uses Logger module to process the outputs of both + modules. + + Args: + q_module: module quantized from float_module that we want to shadow + float_module: float module used to shadow q_module + logger_cls: type of logger used to process the outputs of q_module and + float_module. ShadowLogger or custom loggers can be used. + """ + + def __init__(self, q_module, float_module, logger_cls): + super().__init__() + self.orig_module = q_module + self.shadow_module = float_module + self.dequant = nnq.DeQuantize() + self.logger = logger_cls() + + def forward(self, *x) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + xl = _convert_tuple_to_list(x) + output = self.orig_module(*xl) + xl_float = _dequantize_tensor_list(xl) + shadow_output = self.shadow_module(*xl_float) + self.logger(output, shadow_output) + return output + + def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.add(x, y) + x = x.dequantize() + y = y.dequantize() + shadow_output = self.shadow_module.add(x, y) + self.logger(output, shadow_output) + return output + + def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.add_scalar(x, y) + x = x.dequantize() + shadow_output = self.shadow_module.add_scalar(x, y) + self.logger(output, shadow_output) + return output + + def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.mul(x, y) + x = x.dequantize() + y = y.dequantize() + shadow_output = self.shadow_module.mul(x, y) + self.logger(output, shadow_output) + return output + + def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.mul_scalar(x, y) + x = x.dequantize() + shadow_output = self.shadow_module.mul_scalar(x, y) + self.logger(output, shadow_output) + return output + + def cat(self, x: list[torch.Tensor], dim: int = 0) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.cat(x, dim) + x = [y.dequantize() for y in x] + shadow_output = self.shadow_module.cat(x, dim) + self.logger(output, shadow_output) + return output + + def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + output = self.orig_module.add_relu(x, y) + x = x.dequantize() + y = y.dequantize() + shadow_output = self.shadow_module.add_relu(x, y) + self.logger(output, shadow_output) + return output + + +def prepare_model_with_stubs( + float_module: nn.Module, + q_module: nn.Module, + module_swap_list: set[type], + logger_cls: Callable, +) -> None: + r"""Prepare the model by attaching the float module to its matching quantized + module as the shadow if the float module type is in module_swap_list. + + Example usage:: + + prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger) + q_model(data) + ob_dict = get_logger_dict(q_model) + + Args: + float_module: float module used to generate the q_module + q_module: module quantized from float_module + module_swap_list: list of float module types to attach the shadow + logger_cls: type of logger to be used in shadow module to process the outputs of + quantized module and its float shadow module + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite.prepare_model_with_stubs" + ) + + float_module_children = dict(float_module.named_children()) + + reassign = {} + for name, mod in q_module.named_children(): + if name not in float_module_children: + continue + + float_mod = float_module_children[name] + + if type(float_mod) not in module_swap_list: + prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls) + + # Insert shadow module only if the module is not of the same type as + # the floating point module + if type(float_mod) in module_swap_list and not _is_identical_module_type( + mod, float_mod + ): + reassign[name] = Shadow(mod, float_mod, logger_cls) + + for key, value in reassign.items(): + q_module._modules[key] = value + + +def _is_identical_module_type(mod1, mod2): + # Compare if two modules have the same dtype + mod1_module_types = [type(mod) for mod in mod1.modules()] + mod2_module_types = [type(mod) for mod in mod2.modules()] + return mod1_module_types == mod2_module_types + + +def compare_model_stub( + float_model: nn.Module, + q_model: nn.Module, + module_swap_list: set[type], + *data, + logger_cls=ShadowLogger, +) -> dict[str, dict]: + r"""Compare quantized module in a model with its floating point counterpart, + feeding both of them the same input. Return a dict with key corresponding to + module names and each entry being a dictionary with two keys 'float' and + 'quantized', containing the output tensors of quantized and its matching + float shadow module. This dict can be used to compare and compute the module + level quantization error. + + This function first call prepare_model_with_stubs() to swap the quantized + module that we want to compare with the Shadow module, which takes quantized + module, corresponding float module and logger as input, and creates a forward + path inside to make the float module to shadow quantized module sharing the + same input. The logger can be customizable, default logger is ShadowLogger + and it will save the outputs of the quantized module and float module that + can be used to compute the module level quantization error. + + Example usage:: + + module_swap_list = [ + torchvision.models.quantization.resnet.QuantizableBasicBlock + ] + ob_dict = compare_model_stub(float_model, qmodel, module_swap_list, data) + for key in ob_dict: + print( + key, + compute_error( + ob_dict[key]["float"], ob_dict[key]["quantized"].dequantize() + ), + ) + + Args: + float_model: float model used to generate the q_model + q_model: model quantized from float_model + module_swap_list: list of float module types at which shadow modules will + be attached. + data: input data used to run the prepared q_model + logger_cls: type of logger to be used in shadow module to process the outputs of + quantized module and its float shadow module + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub") + prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls) + q_model(*data) + ob_dict = get_logger_dict(q_model) + return ob_dict + + +def get_matching_activations( + float_module: nn.Module, + q_module: nn.Module, +) -> dict[str, dict[str, torch.Tensor]]: + r"""Find the matching activation between float and quantized modules. + + Args: + float_module: float module used to generate the q_module + q_module: module quantized from float_module + + Return: + act_dict: dict with key corresponding to quantized module names and each + entry being a dictionary with two keys 'float' and 'quantized', containing + the matching float and quantized activations + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite.get_matching_activations" + ) + float_dict = get_logger_dict(float_module) + quantized_dict = get_logger_dict(q_module) + act_dict: dict[str, dict] = {} + for key in quantized_dict: + if len(quantized_dict[key]["tensor_val"]) == 0: + continue + match_key = _find_match(sorted(float_dict, reverse=True), key, "stats") + if match_key is not None: + act_dict[key] = {} + act_dict[key]["float"] = float_dict[match_key]["tensor_val"] + act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"] + return act_dict + + +def prepare_model_outputs( + float_module: nn.Module, + q_module: nn.Module, + logger_cls=OutputLogger, + allow_list=None, +) -> None: + r"""Prepare the model by attaching the logger to both float module + and quantized module if they are in the allow_list. + + Args: + float_module: float module used to generate the q_module + q_module: module quantized from float_module + logger_cls: type of logger to be attached to float_module and q_module + allow_list: list of module types to attach logger + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite.prepare_model_outputs" + ) + if allow_list is None: + allow_list = get_default_compare_output_module_list() + + qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None) + float_module.qconfig = qconfig_debug # type: ignore[assignment] + prepare( + float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={} + ) + q_module.qconfig = qconfig_debug # type: ignore[assignment] + prepare( + q_module, + inplace=True, + allow_list=allow_list, + observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST, + prepare_custom_config_dict={}, + ) + + +def compare_model_outputs( + float_model: nn.Module, + q_model: nn.Module, + *data, + logger_cls=OutputLogger, + allow_list=None, +) -> dict[str, dict[str, torch.Tensor]]: + r"""Compare output activations between float and quantized models at + corresponding locations for the same input. Return a dict with key corresponding + to quantized module names and each entry being a dictionary with two keys + 'float' and 'quantized', containing the activations of quantized model and + float model at matching locations. This dict can be used to compare and + compute the propagation quantization error. + + Example usage:: + + act_compare_dict = compare_model_outputs(float_model, qmodel, data) + for key in act_compare_dict: + print( + key, + compute_error( + act_compare_dict[key]["float"], + act_compare_dict[key]["quantized"].dequantize(), + ), + ) + + Args: + float_model: float model used to generate the q_model + q_model: model quantized from float_model + data: input data used to run the prepared float_model and q_model + logger_cls: type of logger to be attached to float_module and q_module + allow_list: list of module types to attach logger + + Return: + act_compare_dict: dict with key corresponding to quantized module names + and each entry being a dictionary with two keys 'float' and 'quantized', + containing the matching float and quantized activations + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite.compare_model_outputs" + ) + if allow_list is None: + allow_list = get_default_compare_output_module_list() + prepare_model_outputs(float_model, q_model, logger_cls, allow_list) + float_model(*data) + q_model(*data) + act_compare_dict = get_matching_activations(float_model, q_model) + return act_compare_dict diff --git a/.venv/lib/python3.12/site-packages/torch/ao/ns/_numeric_suite_fx.py b/.venv/lib/python3.12/site-packages/torch/ao/ns/_numeric_suite_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..ec13839f3c9b75a15cdcf49c1f1c11f0e9862e14 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/ns/_numeric_suite_fx.py @@ -0,0 +1,1124 @@ +# mypy: allow-untyped-defs +""" +This module contains tooling to compare weights and activations +across models. Example usage:: + + import copy + import torch + import torch.ao.quantization.quantize_fx as quantize_fx + import torch.ao.ns._numeric_suite_fx as ns + + m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval() + mp = quantize_fx.prepare_fx(m, {"": torch.ao.quantization.default_qconfig}) + # We convert a copy because we need the original prepared model + # to be available for comparisons, and `quantize_fx.convert_fx` is inplace. + mq = quantize_fx.convert_fx(copy.deepcopy(mp)) + + # + # Comparing weights + # + + # extract weight pairs + weight_comparison = ns.extract_weights("a", mp, "b", mq) + + # add SQNR for each comparison, inplace + ns.extend_logger_results_with_comparison( + weight_comparison, "a", "b", torch.ao.ns.fx.utils.compute_sqnr, "sqnr" + ) + + # weight_comparison contains the weights from `mp` and `mq` stored + # in pairs, and can be used for further analysis. + + + # + # Comparing activations, with error propagation + # + + # add loggers + mp_ns, mq_ns = ns.add_loggers( + "a", copy.deepcopy(mp), "b", copy.deepcopy(mq), ns.OutputLogger + ) + + # send an example datum to capture intermediate activations + datum = torch.randn(1, 1, 1, 1) + mp_ns(datum) + mq_ns(datum) + + # extract intermediate activations + act_comparison = ns.extract_logger_info(mp_ns, mq_ns, ns.OutputLogger, "b") + + # add SQNR for each comparison, inplace + ns.extend_logger_results_with_comparison( + act_comparison, "a", "b", torch.ao.ns.fx.utils.compute_sqnr, "sqnr" + ) + + # act_comparison contains the activations from `mp_ns` and `mq_ns` stored + # in pairs, and can be used for further analysis. + + # + # Comparing activations, without error propagation + # + + # create shadow model + mp_shadows_mq = ns.add_shadow_loggers( + "a", copy.deepcopy(mp), "b", copy.deepcopy(mq), ns.OutputLogger + ) + + # send an example datum to capture intermediate activations + datum = torch.randn(1, 1, 1, 1) + mp_shadows_mq(datum) + + # extract intermediate activations + shadow_act_comparison = ns.extract_shadow_logger_info( + mp_shadows_mq, ns.OutputLogger, "b" + ) + + # add SQNR for each comparison, inplace + ns.extend_logger_results_with_comparison( + shadow_act_comparison, "a", "b", torch.ao.ns.fx.utils.compute_sqnr, "sqnr" + ) + + # shadow_act_comparison contains the activations from `mp_ns` and `mq_ns` stored + # in pairs, and can be used for further analysis. + +""" + +import collections +from typing import Any, Callable, Optional, TYPE_CHECKING + +import torch +import torch.ao.quantization.quantize_fx as quantize_fx +import torch.nn as nn +from torch.ao.ns.fx.graph_matcher import get_matching_subgraph_pairs +from torch.ao.ns.fx.mappings import get_base_name_to_sets_of_related_ops +from torch.ao.ns.fx.n_shadows_utils import ( + _get_dedup_subgraphs, + create_add_loggers_graph, + create_n_transformed_and_logged_copies_of_subgraph, + create_results_comparison, + extract_weight_comparison, + group_results_by_subgraph, + OutputProp, + print_n_shadows_summary, + SHADOW_WRAPPER_NODE_NAME_PREFIX, +) +from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.backend_config.utils import ( + get_fusion_pattern_to_root_node_getter, +) +from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr +from torch.ao.quantization.fx.match_utils import _find_matches +from torch.ao.quantization.fx.qconfig_mapping_utils import ( + _generate_node_name_to_qconfig, +) +from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers +from torch.fx import GraphModule +from torch.fx.graph import Node + +from .fx.graph_passes import add_loggers_to_model, create_a_shadows_b +from .fx.ns_types import NSNodeTargetType, NSResultsType, NSSingleResultValuesType +from .fx.utils import ( + get_target_type_str, + maybe_add_missing_fqns, + rekey_logger_info_on_node_name_of_model, +) +from .fx.weight_utils import extract_weight_from_node + + +if TYPE_CHECKING: + from torch.ao.quantization.qconfig import QConfigAny + +RNNReturnType = tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] + + +class OutputLogger(nn.Module): + """ + Base class for capturing intermediate values. + """ + + stats: list[torch.Tensor] + stats_rnn: list[RNNReturnType] + + # Mark as impure so that calls to it will not be removed during DCE. + _is_impure = True + + def __init__( + self, + ref_node_name: str, + prev_node_name: str, + model_name: str, + ref_name: str, + prev_node_target_type: str, + ref_node_target_type: str, + results_type: str, + index_within_arg: int, + index_of_arg: int, + fqn: Optional[str], + qconfig_str: Optional[str] = "", + ): + super().__init__() + self.stats: list[torch.Tensor] = [] + self.stats_rnn: list[RNNReturnType] = [] + + # name of the node which was responsible for adding this logger + # Note: + # - if we are logging node outputs, this is the same as prev_node_name + # - if we are logging node inputs, this is the name of the node + # whose input this logger is logging. + # + # example, where logger1 is logging input of op1 and logger2 is logging + # the output of op1: + # + # x1 -> logger1 -> op1 -> logger2 -> x2 + # + # in this example, + # - logger1's prev_node_name is x1 and ref_node_name is op1 + # - logger2's prev_node_name is op1 and ref_node_name is op1 + self.ref_node_name = ref_node_name + # name of the node whose output this Logger is capturing + self.prev_node_name = prev_node_name + + # name of the model from which the node originated from + self.model_name = model_name + # reference name, used to match loggers from separate models + # to each other + self.ref_name = ref_name + # type of the target of the node whose output this logger is logging + self.prev_node_target_type = prev_node_target_type + # type of the target of the node which was responsible for adding this + # logger + self.ref_node_target_type = ref_node_target_type + # what kind of values are inside of stats + self.results_type = results_type + # index of this node within the arg of the input/output node + # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1 + self.index_within_arg = index_within_arg + # index of this node within the args of the input/output node + # for example, in add(x1, x2), x2 would have index_of_arg == 1 + self.index_of_arg = index_of_arg + # fully qualified name + self.fqn = fqn + # if loggers are added before prepare_fx, but we do not want + # collect results of calibration, only results after convert_fx + # so, we add a flag to control whether this logger collects data + self.enabled = True + # string representation of qconfig + self.qconfig_str = qconfig_str + # this can be turned off to reduce memory usage during calibration + self.save_activations = True + + # Note: cannot annotate the type of x because TorchScript does not support + # the Union type. + def forward(self, x): + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + # TODO(future PR): consider designing this better, as the difference + # between these two flags is subtle and not obvious. + if not self.enabled: + return x + if not self.save_activations: + return x + # TODO(future PR): consider refactoring this to better reuse the parent + # class + if isinstance(x, torch.Tensor): + self.stats.append(x.detach()) + elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2: + new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach())) + self.stats_rnn.append(new_res) + return x + + def __repr__(self): + clean_dict = { + k: v + for k, v in self.__dict__.items() + # skip nn.Module keys + if (k != "training") and not k.startswith("_") + } + return f"OutputLogger({clean_dict})" + + +class OutputComparisonLogger(OutputLogger): + """ + Same as OutputLogger, but also requires the original activation + in order to calculate the comparison at calibration time + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO(future PR): make the comparison function configurable + self.comparison_fn = torch.ao.ns.fx.utils.compute_sqnr + self.comparison_fn_name = "sqnr" + # precalculated comparisons of logger output versus reference + self.comparisons = [] + # precalculated comparisons function + + def forward(self, x, x_ref): # type: ignore[override] + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + if not self.enabled: + return x + assert isinstance(x, torch.Tensor), "non-tensor inputs not yet supported" + if self.save_activations: + # save the activation, for debugging + self.stats.append(x.detach()) + # save the comparison + self.comparisons.append(self.comparison_fn(x, x_ref)) + return x + + def __repr__(self): + clean_dict = { + k: v + for k, v in self.__dict__.items() + # skip nn.Module keys + if (k != "training") and not k.startswith("_") + } + return f"OutputComparisonLogger({clean_dict})" + + +class NSTracer(quantize_fx.QuantizationTracer): + """ + Just like a regular FX quantization tracer, but treats observers and fake_quantize + modules as leaf modules. + """ + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + # fmt: off + """ + """ # blank docblock to make autodoc happy + # fmt: on + if isinstance(m, torch.ao.quantization.ObserverBase): + return True + elif isinstance(m, torch.ao.quantization.FakeQuantizeBase): + return True + return super().is_leaf_module(m, module_qualified_name) + + +def _extract_weights_one_model( + model_name: str, + model: GraphModule, + nodes_and_names_to_instrument: list[tuple[Node, str]], + results: NSResultsType, + op_to_type_to_weight_extraction_fn: Optional[ + dict[str, dict[Callable, Callable]] + ] = None, +) -> None: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._extract_weights_one_model" + ) + for node, ref_name in nodes_and_names_to_instrument: + res_type = NSSingleResultValuesType.WEIGHT.value + extracted_weight = extract_weight_from_node( + node, model, op_to_type_to_weight_extraction_fn + ) + if extracted_weight: + if ref_name not in results: + results[ref_name] = {res_type: {}} + results[ref_name][res_type][model_name] = [extracted_weight] + + +def _extract_weights_impl( + model_name_a: str, + gm_a: GraphModule, + model_name_b: str, + gm_b: GraphModule, + base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None, + op_to_type_to_weight_extraction_fn: Optional[ + dict[str, dict[Callable, Callable]] + ] = None, +) -> NSResultsType: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._extract_weights_impl" + ) + matched_subgraph_pairs = get_matching_subgraph_pairs( + gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map + ) + + # split the subgraph pairs into one data structure for each model + nodes_and_names_to_instrument_a: list[tuple[Node, str]] = [] + nodes_and_names_to_instrument_b: list[tuple[Node, str]] = [] + for match_name, match in matched_subgraph_pairs.items(): + subgraph_a, subgraph_b = match + nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name)) + nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name)) + + # populate the results, one model at a time + results: NSResultsType = {} + _extract_weights_one_model( + model_name_a, + gm_a, + nodes_and_names_to_instrument_a, + results, + op_to_type_to_weight_extraction_fn, + ) + _extract_weights_one_model( + model_name_b, + gm_b, + nodes_and_names_to_instrument_b, + results, + op_to_type_to_weight_extraction_fn, + ) + + # fill in missing fqn entries + maybe_add_missing_fqns(results) + + # rekey on names of nodes in gm_b + results = rekey_logger_info_on_node_name_of_model(results, model_name_b) + + return results + + +def extract_weights( + model_name_a: str, + model_a: nn.Module, + model_name_b: str, + model_b: nn.Module, + base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None, + op_to_type_to_weight_extraction_fn: Optional[ + dict[str, dict[Callable, Callable]] + ] = None, +) -> NSResultsType: + """ + Extract weights from model A and model B, and return a comparison. + + Args: + model_name_a: string name of model A to use in results + model_a: model A + model_name_b: string name of model B to use in results + model_b: model B + base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change + unmatchable_types_map: optional override of unmatchable types, subject to change + op_to_type_to_weight_extraction_fn: optional override of function which extracts weight + from a type, subject to change + + Return: + NSResultsType, containing the weight comparisons + """ + + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights") + if base_name_to_sets_of_related_ops is None: + base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops() + + # TODO(future PR): expose these + skipped_module_names: list[str] = [] + skipped_module_classes: list[Callable] = [] + tracer_a = NSTracer(skipped_module_names, skipped_module_classes) + tracer_b = NSTracer(skipped_module_names, skipped_module_classes) + gm_a = GraphModule(model_a, tracer_a.trace(model_a)) + maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr( + model_a, "node_name_to_scope" + ) + if maybe_model_a_node_name_to_scope is not None: + gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope + gm_b = GraphModule(model_b, tracer_b.trace(model_b)) + maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr( + model_b, "node_name_to_scope" + ) + if maybe_model_b_node_name_to_scope is not None: + gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope + return _extract_weights_impl( + model_name_a, + gm_a, + model_name_b, + gm_b, + base_name_to_sets_of_related_ops, + unmatchable_types_map, + op_to_type_to_weight_extraction_fn, + ) + + +def _add_loggers_one_model( + model_name: str, + model: GraphModule, + nodes_and_names_to_instrument_inputs: list[tuple[Node, str, str]], + nodes_and_names_to_instrument_outputs: list[tuple[Node, str, str]], + logger_cls: Callable, +) -> nn.Module: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._add_loggers_one_model" + ) + + # TODO(future PR): do not observe nodes we do not care + # about (both fp32, denylist, etc) + node_to_instrument_inputs_to_ref_name: dict[Node, tuple[str, str]] = {} + node_to_instrument_outputs_to_ref_name: dict[Node, tuple[str, str]] = {} + for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs: + node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type) + for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs: + node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type) + + model = add_loggers_to_model( + model, + node_to_instrument_inputs_to_ref_name, + node_to_instrument_outputs_to_ref_name, + logger_cls, + model_name, + ) + return model + + +def _add_loggers_impl( + name_a: str, + gm_a: GraphModule, + name_b: str, + gm_b: GraphModule, + logger_cls: Callable, + should_log_inputs: bool, + base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None, +) -> tuple[nn.Module, nn.Module]: + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl") + matched_subgraph_pairs = get_matching_subgraph_pairs( + gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map + ) + nodes_and_names_to_instrument_inputs_a = [] + nodes_and_names_to_instrument_inputs_b = [] + nodes_and_names_to_instrument_outputs_a = [] + nodes_and_names_to_instrument_outputs_b = [] + for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items(): + ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a) + ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b) + # Note: for matching inputs we use start_node, such as observing + # the input of linear in linear-relu + if should_log_inputs: + nodes_and_names_to_instrument_inputs_a.append( + (subgraph_a.start_node, match_name, ref_node_type_a) + ) + nodes_and_names_to_instrument_inputs_b.append( + (subgraph_b.start_node, match_name, ref_node_type_b) + ) + # Note: for matching activations we always use end_node, + # such as observing the output of relu in linear-relu + nodes_and_names_to_instrument_outputs_a.append( + (subgraph_a.end_node, match_name, ref_node_type_a) + ) + nodes_and_names_to_instrument_outputs_b.append( + (subgraph_b.end_node, match_name, ref_node_type_b) + ) + + new_model_a = _add_loggers_one_model( + name_a, + gm_a, + nodes_and_names_to_instrument_inputs_a, + nodes_and_names_to_instrument_outputs_a, + logger_cls, + ) + new_model_b = _add_loggers_one_model( + name_b, + gm_b, + nodes_and_names_to_instrument_inputs_b, + nodes_and_names_to_instrument_outputs_b, + logger_cls, + ) + return (new_model_a, new_model_b) + + +def add_loggers( + name_a: str, + model_a: nn.Module, + name_b: str, + model_b: nn.Module, + logger_cls: Callable, + should_log_inputs: bool = False, + base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None, +) -> tuple[nn.Module, nn.Module]: + """ + Instrument model A and model B with loggers. + + Args: + name_a: string name of model A to use in results + model_a: model A + name_b: string name of model B to use in results + model_b: model B + logger_cls: class of Logger to use + base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change + unmatchable_types_map: optional override of unmatchable types, subject to change + + Return: + Returns a tuple of (model_a_with_loggers, model_b_with_loggers). Modifies both models inplace. + """ + + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers") + # TODO(future PR): expose these + skipped_module_names: list[str] = [] + skipped_module_classes: list[Callable] = [] + tracer_a = NSTracer(skipped_module_names, skipped_module_classes) + tracer_b = NSTracer(skipped_module_names, skipped_module_classes) + gm_a = GraphModule(model_a, tracer_a.trace(model_a)) + maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr( + model_a, "node_name_to_scope" + ) + if maybe_model_a_node_name_to_scope is not None: + gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope + gm_b = GraphModule(model_b, tracer_b.trace(model_b)) + maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr( + model_b, "node_name_to_scope" + ) + if maybe_model_b_node_name_to_scope is not None: + gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope + return _add_loggers_impl( + name_a, + gm_a, + name_b, + gm_b, + logger_cls, + should_log_inputs=should_log_inputs, + base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops, + unmatchable_types_map=unmatchable_types_map, + ) + + +def _extract_logger_info_one_model( + model: nn.Module, + results: NSResultsType, + logger_cls: Callable, +) -> None: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._extract_logger_info_one_model" + ) + for _gm_name, mod in model.named_modules(): + # TODO(future PR): better check when scripted + is_logger = isinstance(mod, logger_cls) or ( # type: ignore[arg-type] + isinstance(mod, torch.jit.RecursiveScriptModule) + and mod.original_name == "OutputLogger" + ) + if is_logger: + key = mod.ref_name + if key not in results: + results[key] = {} + assert mod.model_name not in results[key], ( + f"{mod.model_name} is already present in results" + ) + if mod.results_type not in results[key]: + results[key][mod.results_type] = {} + if mod.model_name not in results[key][mod.results_type]: + results[key][mod.results_type][mod.model_name] = [] + stats_to_use = mod.stats + if len(mod.stats_rnn) > 0: + stats_to_use = mod.stats_rnn + data = { + "type": mod.results_type, + "values": stats_to_use, + "ref_node_name": mod.ref_node_name, + "ref_node_target_type": mod.ref_node_target_type, + "prev_node_name": mod.prev_node_name, + "prev_node_target_type": mod.prev_node_target_type, + "index_within_arg": mod.index_within_arg, + "index_of_arg": mod.index_of_arg, + "fqn": mod.fqn, + "qconfig_str": mod.qconfig_str, + } + if hasattr(mod, "comparisons"): + data["comparisons"] = mod.comparisons + data["comparison_fn_name"] = mod.comparison_fn_name + else: + data["comparisons"] = [] + data["comparison_fn_name"] = "" + results[key][mod.results_type][mod.model_name].append(data) + # ensure the list stays sorted + results[key][mod.results_type][mod.model_name].sort( + key=lambda res: f"{res['index_of_arg']}:{res['index_within_arg']}" + ) + + +# TODO(future PR): align on naming +# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs` +def extract_logger_info( + model_a: nn.Module, + model_b: nn.Module, + logger_cls: Callable, + model_name_to_use_for_layer_names: str, +) -> NSResultsType: + """ + Traverse all loggers in `model_a` and `model_b`, and extract the logged + information. + + Args: + model_a: model A + model_b: model B + logger_cls: class of Logger to use + model_name_to_use_for_layer_names: string name of model to use for + layer names in the output + + Return: + NSResultsType, containing the logged comparisons + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx.extract_logger_info" + ) + results: NSResultsType = {} + for model in (model_a, model_b): + _extract_logger_info_one_model(model, results, logger_cls) + # fill in missing fqn entries + maybe_add_missing_fqns(results) + # rekey on the name of model b + results = rekey_logger_info_on_node_name_of_model( + results, model_name_to_use_for_layer_names + ) + return results + + +def _add_shadow_loggers_impl( + name_a: str, + gm_a: GraphModule, + name_b: str, + gm_b: GraphModule, + logger_cls: Callable, + should_log_inputs: bool, + base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None, + node_type_to_io_type_map: Optional[dict[str, set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None, +) -> nn.Module: + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx._add_shadow_loggers_impl" + ) + matched_subgraph_pairs = get_matching_subgraph_pairs( + gm_a, gm_b, base_name_to_sets_of_related_ops, unmatchable_types_map + ) + gm_a_shadows_b = create_a_shadows_b( + name_a, + gm_a, + name_b, + gm_b, + matched_subgraph_pairs, + logger_cls, + should_log_inputs=should_log_inputs, + node_type_to_io_type_map=node_type_to_io_type_map, + ) + return gm_a_shadows_b + + +def add_shadow_loggers( + name_a: str, + model_a: nn.Module, + name_b: str, + model_b: nn.Module, + logger_cls: Callable, + should_log_inputs: bool = False, + base_name_to_sets_of_related_ops: Optional[dict[str, set[NSNodeTargetType]]] = None, + node_type_to_io_type_map: Optional[dict[str, set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[dict[str, set[NSNodeTargetType]]] = None, +) -> nn.Module: + """ + Instrument model A and model B with shadow loggers. + + Args: + name_a: string name of model A to use in results + model_a: model A + name_b: string name of model B to use in results + model_b: model B + logger_cls: class of Logger to use + should_log_inputs: whether to log inputs + base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change + unmatchable_types_map: optional override of unmatchable types, subject to change + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx.add_shadow_loggers" + ) + # TODO(future PR): expose these + skipped_module_names: list[str] = [] + skipped_module_classes: list[Callable] = [] + tracer_a = NSTracer(skipped_module_names, skipped_module_classes) + tracer_b = NSTracer(skipped_module_names, skipped_module_classes) + gm_a = GraphModule(model_a, tracer_a.trace(model_a)) + maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr( + model_a, "node_name_to_scope" + ) + if maybe_model_a_node_name_to_scope is not None: + gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope + gm_b = GraphModule(model_b, tracer_b.trace(model_b)) + maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr( + model_b, "node_name_to_scope" + ) + if maybe_model_b_node_name_to_scope is not None: + gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope + return _add_shadow_loggers_impl( + name_a, + gm_a, + name_b, + gm_b, + logger_cls, + should_log_inputs=should_log_inputs, + base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops, + node_type_to_io_type_map=node_type_to_io_type_map, + unmatchable_types_map=unmatchable_types_map, + ) + + +def extract_shadow_logger_info( + model_a_shadows_b: nn.Module, + logger_cls: Callable, + model_name_to_use_for_layer_names: str, +) -> NSResultsType: + """ + Traverse all loggers in a shadow model, and extract the logged + information. + + Args: + model_a_shadows_b: shadow model + logger_cls: class of Logger to use + model_name_to_use_for_layer_names: string name of model to use for + layer names in the output + + Return: + NSResultsType, containing the logged comparisons + """ + torch._C._log_api_usage_once( + "quantization_api._numeric_suite_fx.extract_shadow_logger_info" + ) + results: NSResultsType = collections.defaultdict(dict) + _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls) + # fill in missing fqn entries + maybe_add_missing_fqns(results) + # rekey on the name of model b + results = rekey_logger_info_on_node_name_of_model( + results, model_name_to_use_for_layer_names + ) + return dict(results) + + +def extend_logger_results_with_comparison( + results: NSResultsType, + model_name_1: str, + model_name_2: str, + comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + comparison_name: str, +) -> None: + """ + Compares the logged values from `model_name_2` against the corresponding + values in `model_name_1`, using `comparison_fn`. Records the result + in `model_name_2`'s results under `comparison_name`. Modifies `results` inplace. + + Args: + results: the result data structure from `extract_logger_info` or + `extract_shadow_logger_info`. + model_name_1: string name of model 1 + model_name_2: string name of model 2 + comparison_fn: function to compare two Tensors + comparison_name: string name of model to use for + layer names in the output + """ + for results_type_to_results in results.values(): + for model_name_to_results in results_type_to_results.values(): + assert model_name_1 in model_name_to_results, ( + f"{model_name_1} not found in results" + ) + assert model_name_2 in model_name_to_results, ( + f"{model_name_2} not found in results" + ) + + results_1 = model_name_to_results[model_name_1] + results_2 = model_name_to_results[model_name_2] + + for result_2 in results_2: + index_within_arg_2 = result_2["index_within_arg"] + index_of_arg_2 = result_2["index_of_arg"] + # find corresponding result_1 + result_1 = None + for cur_result_1 in results_1: + index_within_arg_1 = cur_result_1["index_within_arg"] + index_of_arg_1 = cur_result_1["index_of_arg"] + if (index_within_arg_1 == index_within_arg_2) and ( + index_of_arg_1 == index_of_arg_2 + ): + result_1 = cur_result_1 + break + assert result_1 is not None + + values_1 = result_1["values"] + values_2 = result_2["values"] + result_2[comparison_name] = [] + for value_1, value_2 in zip(values_1, values_2): + comparison_result = comparison_fn(value_1, value_2) + result_2[comparison_name].append(comparison_result) + + +def prepare_n_shadows_model( + model: torch.nn.Module, + example_inputs: Any, + qconfig_multi_mapping: QConfigMultiMapping, + backend_config: BackendConfig, + custom_prepare_fn: Optional[Callable] = None, + custom_prepare_kwargs: Optional[dict[str, Any]] = None, + custom_tracer: Any = None, +) -> GraphModule: + """ + Given a model with a graph with M ops such as + + + args_kwargs_m -> op_m -> output_m + + + And a set of N qconfigs for each op, creates a new model, with + each of the subgraph of `op_m` transformed into + + .. code:: + + |---------> op_m_n -> log_m_n + | / + args_kwargs_m ---------> op_m -> log_m_0 + + Where op_m_n is op_m wrapped in a submodule and transformed with + qconfig_n, and its inner graph looks like + + .. code:: + + args_m -------- op_m_prepared_with_qconfig_n -> out_m_n + / + kwargs_m --- + + This is useful for testing different quantization of multiple layers in + a single pass through the model. + + High level TODOs for future PRs: + * figure out a better way to name the output structure + * return a results data structure instead of printing it out + * add examples to docblocks + """ + + if custom_tracer is None: + tracer = quantize_fx.QuantizationTracer([], []) + else: + tracer = custom_tracer + mt = torch.fx.GraphModule(model, tracer.trace(model)) + # this is necessary to ensure logger FQNs get populated + mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment] + + # run example input propagation, we need this to call prepare_fx on + # individual subgraphs + output_prop = OutputProp(mt) + output_prop.propagate(*example_inputs) + + # Find the set of subgraphs in the original graph which we need to + # consider. + modules = dict(mt.named_modules(remove_duplicate=False)) + patterns = _get_pattern_to_quantize_handlers(backend_config) + root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config) + standalone_module_names: list[str] = [] + standalone_module_classes: list[type] = [] + custom_module_classes: list[type] = [] + matches = _find_matches( + mt.graph, + modules, + patterns, + root_node_getter_mapping, + standalone_module_names, + standalone_module_classes, + custom_module_classes, + ) + subgraphs_dedup: dict[str, list[Node]] = _get_dedup_subgraphs(matches) + + # generate node to qconfig for each subgraph + # TODO(future PR): deduplicate repeating entries + list_of_node_name_to_qconfig: list[dict[str, QConfigAny]] = [] + for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list: + node_name_to_qconfig = _generate_node_name_to_qconfig( + mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope + ) + list_of_node_name_to_qconfig.append(node_name_to_qconfig) + + # For each region in the model, do the following: + # For each qconfig for that region, do the following: + # 1. create a copy of the region wrapped in a module + # 2. pass original args, original kwargs, and expected output to module + # 3. add an output comparison logger and hook it up to compare + # actual output to expected output + # 4. run `prepare_fx` on the module + for subgraph_idx, (match_name, nodes_in_this_subgraph) in enumerate( + subgraphs_dedup.items() + ): + create_n_transformed_and_logged_copies_of_subgraph( + mt, + subgraph_idx, + match_name, + nodes_in_this_subgraph, + qconfig_multi_mapping.qconfig_mappings_list, + list_of_node_name_to_qconfig, + custom_prepare_fn, + custom_prepare_kwargs, # type: ignore[arg-type] + ) + + return mt + + +# TODO(future PR): we should rethink the names of all the PNP APIs +def _prepare_n_shadows_add_loggers_model( + model: torch.nn.Module, + example_inputs: Any, + qconfig_mapping: QConfigMapping, + backend_config: BackendConfig, +) -> torch.nn.Module: + r""" + Note: this API is not recommended for wide usage, it is only + provided for customers who need to migrate from the `add_loggers` + API. + + This creates a model which provides logging for the following + problem: if we quantize `model` with `qconfig_mapping` and feed + the same input through both models, log the comparisons of + corresponding intermediate layers. + + The problem is solved with a single model. Specifically, we + partition `model` into N subgraphs, create a copy of each relevant + subgraph, wrap it in a module, apply the quantization API to that + module, and hook up loggers to measure the comparisons. + + Example starting graph: + + x0 -> op0 -> x1 -> op1 -> x2 + + Example config: quantize op0 to int8, do nothing to op1. + The following graph will be created: + + .. code:: + + x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log + \ \ \ # noqa: W605 + ---> op0_1 -> x1_1 ----> clog -> op1_0 -> x2_1 ----> clog + + Where op0_0 is op0, op0_1 is op0 wrapped in a submodule and quantized + to int8, op1_0 is op1 (appearing in the graph twice), log is a logger, + and clog is a comparison logger. + """ + + tracer = quantize_fx.QuantizationTracer([], []) + mt = torch.fx.GraphModule(model, tracer.trace(model)) + # this is necessary to ensure logger FQNs get populated + mt._node_name_to_scope = tracer.node_name_to_scope # type: ignore[assignment] + + # run example input propagation, we need this to call prepare_fx on + # individual subgraphs + output_prop = OutputProp(mt) + output_prop.propagate(*example_inputs) + + # Find the set of subgraphs in the original graph which we need to + # consider. + modules = dict(mt.named_modules(remove_duplicate=False)) + patterns = _get_pattern_to_quantize_handlers(backend_config) + root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config) + standalone_module_names: list[str] = [] + standalone_module_classes: list[type] = [] + custom_module_classes: list[type] = [] + matches = _find_matches( + mt.graph, + modules, + patterns, + root_node_getter_mapping, + standalone_module_names, + standalone_module_classes, + custom_module_classes, + ) + subgraphs_dedup: dict[str, list[Node]] = _get_dedup_subgraphs(matches) + + # generate node to qconfig for each subgraph + node_name_to_qconfig = _generate_node_name_to_qconfig( + mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope + ) + + # Now, mutate the graph to be the add_loggers graph with propagation + # error. + create_add_loggers_graph(mt, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig) + + return mt + + +# TODO(future PR): we should rethink the names of all the PNP APIs +def _n_shadows_compare_weights( + model: torch.nn.Module, + example_inputs: Any, + qconfig_mapping: QConfigMapping, + backend_config: BackendConfig, +) -> NSResultsType: + """ + Note: this API is not recommended for wide usage, it is only + provided for customers who need to migrate from the `add_loggers` + API. + """ + qconfig_multi_mapping = QConfigMultiMapping.from_list_qconfig_mapping( + [qconfig_mapping] + ) + mp = prepare_n_shadows_model( + model, example_inputs, qconfig_multi_mapping, backend_config + ) + # passing inputs through the model is necessary to populate + # observers which observe weights with real values + mp(*example_inputs) + mq = convert_n_shadows_model(mp) + weight_comparison = extract_weight_comparison(mq) + return weight_comparison + + +# TODO(future PR): consider aligning API signature with other similar quantization +# functions (enable_fake_quant, etc) +def loggers_set_enabled(model: torch.nn.Module, enabled: bool) -> None: + """ + Sets the `enabled` setting on a `model`'s loggers + """ + for _, child in model.named_modules(): + if isinstance(child, OutputLogger): + child.enabled = enabled + + +# TODO(future PR): consider aligning API signature with other similar quantization +# functions (enable_fake_quant, etc) +def loggers_set_save_activations( + model: torch.nn.Module, + save_activations: bool, +) -> None: + """ + Sets the `save_activations` setting on a `model`'s loggers + """ + for _name, child in model.named_modules(): + if isinstance(child, OutputLogger): + child.save_activations = save_activations + + +def convert_n_shadows_model( + model: GraphModule, + custom_convert_fn: Optional[Callable] = None, + custom_convert_kwargs: Optional[dict[str, Any]] = None, +) -> GraphModule: + """ + Given a model from `prepare_n_shadows_model`, runs `convert_fx` + on each shadow submodule. + """ + for node in model.graph.nodes: + # TODO(future PR): consider matching in a safer way than + # node name string match + if node.name.startswith(SHADOW_WRAPPER_NODE_NAME_PREFIX): + orig_mod = getattr(model, node.name) + if custom_convert_fn is None: + converted_mod = torch.ao.quantization.quantize_fx.convert_fx(orig_mod) + else: + if custom_convert_kwargs is None: + custom_convert_kwargs = {} + converted_mod = custom_convert_fn(orig_mod, **custom_convert_kwargs) + setattr(model, node.name, converted_mod) + + return model + + +def extract_results_n_shadows_model(model: torch.nn.Module) -> NSResultsType: + """ + Extracts logger results from `model`. + """ + results: NSResultsType = {} + _extract_logger_info_one_model(model, results, OutputLogger) + return results + + +def print_comparisons_n_shadows_model(results: NSResultsType) -> None: + """ + Prints a summary of extracted `results`. + """ + results_grouped = group_results_by_subgraph(results) + results_comparison = create_results_comparison(results_grouped) + print_n_shadows_summary(results_comparison) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52fc301befd34642d51f1c27e07600a1f3ef26ff --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/__init__.py @@ -0,0 +1,23 @@ +# Variables +from ._mappings import ( + get_dynamic_sparse_quantized_mapping, + get_static_sparse_quantized_mapping, +) + +# Scheduler +from .scheduler.base_scheduler import BaseScheduler +from .scheduler.cubic_scheduler import CubicSL +from .scheduler.lambda_scheduler import LambdaSL + +# Sparsifier +from .sparsifier.base_sparsifier import BaseSparsifier +from .sparsifier.nearly_diagonal_sparsifier import NearlyDiagonalSparsifier + +# Parametrizations +from .sparsifier.utils import ( + FakeSparsity, + fqn_to_module, + get_arg_info_from_tensor_fqn, + module_to_fqn, +) +from .sparsifier.weight_norm_sparsifier import WeightNormSparsifier diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..acf2126c1e4f4556c9be1a19e83b85c6a1e8be04 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py @@ -0,0 +1,476 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from collections import defaultdict +from typing import Any, Optional + +import torch +from torch import nn +from torch.ao.pruning.sparsifier.utils import fqn_to_module, module_to_fqn + + +__all__ = ["ActivationSparsifier"] + + +class ActivationSparsifier: + r""" + The Activation sparsifier class aims to sparsify/prune activations in a neural + network. The idea is to attach the sparsifier to a layer (or layers) and it + zeroes out the activations based on the mask_fn (or sparsification function) + input by the user. + The mask_fn is applied once all the inputs are aggregated and reduced i.e. + mask = mask_fn(reduce_fn(aggregate_fn(activations))) + + Note:: + The sparsification mask is computed on the input **before it goes through the attached layer**. + + Args: + model (nn.Module): + The model whose layers will be sparsified. The layers that needs to be + sparsified should be added separately using the register_layer() function + aggregate_fn (Optional, Callable): + default aggregate_fn that is used if not specified while registering the layer. + specifies how inputs should be aggregated over time. + The aggregate_fn should usually take 2 torch tensors and return the aggregated tensor. + Example + def add_agg_fn(tensor1, tensor2): return tensor1 + tensor2 + reduce_fn (Optional, Callable): + default reduce_fn that is used if not specified while registering the layer. + reduce_fn will be called on the aggregated tensor i.e. the tensor obtained after + calling agg_fn() on all inputs. + Example + def mean_reduce_fn(agg_tensor): return agg_tensor.mean(dim=0) + mask_fn (Optional, Callable): + default mask_fn that is used to create the sparsification mask using the tensor obtained after + calling the reduce_fn(). This is used by default if a custom one is passed in the + register_layer(). + Note that the mask_fn() definition should contain the sparse arguments that is passed in sparse_config + arguments. + features (Optional, list): + default selected features to sparsify. + If this is non-empty, then the mask_fn will be applied for each feature of the input. + For example, + mask = [mask_fn(reduce_fn(aggregated_fn(input[feature])) for feature in features] + feature_dim (Optional, int): + default dimension of input features. Again, features along this dim will be chosen + for sparsification. + sparse_config (Dict): + Default configuration for the mask_fn. This config will be passed + with the mask_fn() + + Example: + >>> # xdoctest: +SKIP + >>> model = SomeModel() + >>> act_sparsifier = ActivationSparsifier(...) # init activation sparsifier + >>> # Initialize aggregate_fn + >>> def agg_fn(x, y): + >>> return x + y + >>> + >>> # Initialize reduce_fn + >>> def reduce_fn(x): + >>> return torch.mean(x, dim=0) + >>> + >>> # Initialize mask_fn + >>> def mask_fn(data): + >>> return torch.eye(data.shape).to(data.device) + >>> + >>> + >>> act_sparsifier.register_layer( + ... model.some_layer, + ... aggregate_fn=agg_fn, + ... reduce_fn=reduce_fn, + ... mask_fn=mask_fn, + ... ) + >>> + >>> # start training process + >>> for _ in [...]: + >>> # epoch starts + >>> # model.forward(), compute_loss() and model.backwards() + >>> # epoch ends + >>> act_sparsifier.step() + >>> # end training process + >>> sparsifier.squash_mask() + """ + + def __init__( + self, + model: nn.Module, + aggregate_fn=None, + reduce_fn=None, + mask_fn=None, + features=None, + feature_dim=None, + **sparse_config, + ): + self.model = model + self.defaults: dict[str, Any] = defaultdict() + self.defaults["sparse_config"] = sparse_config + + # functions + self.defaults["aggregate_fn"] = aggregate_fn + self.defaults["reduce_fn"] = reduce_fn + self.defaults["mask_fn"] = mask_fn + + # default feature and feature_dim + self.defaults["features"] = features + self.defaults["feature_dim"] = feature_dim + + self.data_groups: dict[str, dict] = defaultdict( + dict + ) # contains all relevant info w.r.t each registered layer + + self.state: dict[str, Any] = defaultdict(dict) # layer name -> mask + + @staticmethod + def _safe_rail_checks(args): + """Makes sure that some of the functions and attributes are not passed incorrectly""" + + # if features are not None, then feature_dim must not be None + features, feature_dim = args["features"], args["feature_dim"] + if features is not None: + assert feature_dim is not None, "need feature dim to select features" + + # all the *_fns should be callable + fn_keys = ["aggregate_fn", "reduce_fn", "mask_fn"] + for key in fn_keys: + fn = args[key] + assert callable(fn), "function should be callable" + + def _aggregate_hook(self, name): + """Returns hook that computes aggregate of activations passing through.""" + + # gather some data + feature_dim = self.data_groups[name]["feature_dim"] + features = self.data_groups[name]["features"] + agg_fn = self.data_groups[name]["aggregate_fn"] + + def hook(module, input) -> None: + input_data = input[0] + + data = self.data_groups[name].get("data") # aggregated data + if features is None: + # no features associated, data should not be a list + if data is None: + data = torch.zeros_like(input_data) + self.state[name]["mask"] = torch.ones_like(input_data) + out_data = agg_fn(data, input_data) + else: + # data should be a list [aggregated over each feature only] + if data is None: + out_data = [ + 0 for _ in range(0, len(features)) + ] # create one incase of 1st forward + self.state[name]["mask"] = [0 for _ in range(0, len(features))] + else: + out_data = data # a list + + # compute aggregate over each feature + for feature_idx in range(len(features)): + # each feature is either a list or scalar, convert it to torch tensor + feature_tensor = ( + torch.Tensor([features[feature_idx]]) + .long() + .to(input_data.device) + ) + data_feature = torch.index_select( + input_data, feature_dim, feature_tensor + ) + if data is None: + curr_data = torch.zeros_like(data_feature) + self.state[name]["mask"][feature_idx] = torch.ones_like( + data_feature + ) + else: + curr_data = data[feature_idx] + out_data[feature_idx] = agg_fn(curr_data, data_feature) + self.data_groups[name]["data"] = out_data + + return hook + + def register_layer( + self, + layer: nn.Module, + aggregate_fn=None, + reduce_fn=None, + mask_fn=None, + features=None, + feature_dim=None, + **sparse_config, + ): + r""" + Registers a layer for sparsification. The layer should be part of self.model. + Specifically, registers a pre-forward hook to the layer. The hook will apply the aggregate_fn + and store the aggregated activations that is input over each step. + + Note:: + - There is no need to pass in the name of the layer as it is automatically computed as per + the fqn convention. + + - All the functions (fn) passed as argument will be called at a dim, feature level. + """ + name = module_to_fqn(self.model, layer) + assert name is not None, "layer not found in the model" # satisfy mypy + + if name in self.data_groups: # unregister layer if already present + warnings.warn( + "layer already attached to the sparsifier, deregistering the layer and registering with new config" + ) + self.unregister_layer(name=name) + + local_args = copy.deepcopy(self.defaults) + update_dict = { + "aggregate_fn": aggregate_fn, + "reduce_fn": reduce_fn, + "mask_fn": mask_fn, + "features": features, + "feature_dim": feature_dim, + "layer": layer, + } + local_args.update( + (arg, val) for arg, val in update_dict.items() if val is not None + ) + local_args["sparse_config"].update(sparse_config) + + self._safe_rail_checks(local_args) + + self.data_groups[name] = local_args + agg_hook = layer.register_forward_pre_hook(self._aggregate_hook(name=name)) + + self.state[name]["mask"] = ( + None # mask will be created when model forward is called. + ) + + # attach agg hook + self.data_groups[name]["hook"] = agg_hook + + # for serialization purposes, we know whether aggregate_hook is attached + # or sparsify_hook() + self.data_groups[name]["hook_state"] = "aggregate" # aggregate hook is attached + + def get_mask(self, name: Optional[str] = None, layer: Optional[nn.Module] = None): + """ + Returns mask associated to the layer. + + The mask is + - a torch tensor is features for that layer is None. + - a list of torch tensors for each feature, otherwise + + Note:: + The shape of the mask is unknown until model.forward() is applied. + Hence, if get_mask() is called before model.forward(), an + error will be raised. + """ + assert name is not None or layer is not None, ( + "Need at least name or layer obj to retrieve mask" + ) + + if name is None: + assert layer is not None + name = module_to_fqn(self.model, layer) + assert name is not None, "layer not found in the specified model" + + if name not in self.state: + raise ValueError("Error: layer with the given name not found") + + mask = self.state[name].get("mask", None) + + if mask is None: + raise ValueError( + "Error: shape unknown, call layer() routine at least once to infer mask" + ) + return mask + + def unregister_layer(self, name): + """Detaches the sparsifier from the layer""" + + # detach any hooks attached + self.data_groups[name]["hook"].remove() + + # pop from the state dict + self.state.pop(name) + + # pop from the data groups + self.data_groups.pop(name) + + def step(self): + """Internally calls the update_mask() function for each layer""" + with torch.no_grad(): + for name, configs in self.data_groups.items(): + data = configs["data"] + self.update_mask(name, data, configs) + + self.data_groups[name].pop("data") # reset the accumulated data + + def update_mask(self, name, data, configs): + """ + Called for each registered layer and does the following- + 1. apply reduce_fn on the aggregated activations + 2. use mask_fn to compute the sparsification mask + + Note: + the reduce_fn and mask_fn is called for each feature, dim over the data + """ + mask = self.get_mask(name) + sparse_config = configs["sparse_config"] + features = configs["features"] + reduce_fn = configs["reduce_fn"] + mask_fn = configs["mask_fn"] + if features is None: + data = reduce_fn(data) + mask.data = mask_fn(data, **sparse_config) + else: + for feature_idx in range(len(features)): + data_feature = reduce_fn(data[feature_idx]) + mask[feature_idx].data = mask_fn(data_feature, **sparse_config) + + def _sparsify_hook(self, name): + """Returns hook that applies sparsification mask to input entering the attached layer""" + mask = self.get_mask(name) + features = self.data_groups[name]["features"] + feature_dim = self.data_groups[name]["feature_dim"] + + def hook(module, input): + input_data = input[0] + if features is None: + # apply to all the features + return input_data * mask + else: + # apply per feature, feature_dim + for feature_idx in range(0, len(features)): + feature = ( + torch.Tensor([features[feature_idx]]) + .long() + .to(input_data.device) + ) + sparsified = ( + torch.index_select(input_data, feature_dim, feature) + * mask[feature_idx] + ) + input_data.index_copy_(feature_dim, feature, sparsified) + return input_data + + return hook + + def squash_mask(self, attach_sparsify_hook=True, **kwargs): + """ + Unregisters aggregate hook that was applied earlier and registers sparsification hooks if + attach_sparsify_hook = True. + """ + for name, configs in self.data_groups.items(): + # unhook agg hook + configs["hook"].remove() + configs.pop("hook") + self.data_groups[name]["hook_state"] = "None" + if attach_sparsify_hook: + configs["hook"] = configs["layer"].register_forward_pre_hook( + self._sparsify_hook(name) + ) + configs["hook_state"] = ( + "sparsify" # signals that sparsify hook is now attached + ) + + def _get_serializable_data_groups(self): + """Exclude hook and layer from the config keys before serializing + + TODO: Might have to treat functions (reduce_fn, mask_fn etc) in a different manner while serializing. + For time-being, functions are treated the same way as other attributes + """ + data_groups: dict[str, Any] = defaultdict() + for name, config in self.data_groups.items(): + new_config = { + key: value + for key, value in config.items() + if key not in ["hook", "layer"] + } + data_groups[name] = new_config + return data_groups + + def _convert_mask(self, states_dict, sparse_coo=True): + r"""Converts the mask to sparse coo or dense depending on the `sparse_coo` argument. + If `sparse_coo=True`, then the mask is stored as sparse coo else dense tensor + """ + states = copy.deepcopy(states_dict) + for state in states.values(): + if state["mask"] is not None: + if isinstance(state["mask"], list): + for idx in range(len(state["mask"])): + if sparse_coo: + state["mask"][idx] = state["mask"][idx].to_sparse_coo() + else: + state["mask"][idx] = state["mask"][idx].to_dense() + else: + if sparse_coo: + state["mask"] = state["mask"].to_sparse_coo() + else: + state["mask"] = state["mask"].to_dense() + return states + + def state_dict(self) -> dict[str, Any]: + r"""Returns the state of the sparsifier as a :class:`dict`. + + It contains: + * state - contains name -> mask mapping. + * data_groups - a dictionary containing all config information for each + layer + * defaults - the default config while creating the constructor + """ + data_groups = self._get_serializable_data_groups() + state = self._convert_mask(self.state) + return {"state": state, "data_groups": data_groups, "defaults": self.defaults} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + r"""The load_state_dict() restores the state of the sparsifier based on the state_dict + + Args: + * state_dict - the dictionary that to which the current sparsifier needs to be restored to + """ + state = state_dict["state"] + data_groups, defaults = state_dict["data_groups"], state_dict["defaults"] + + self.__set_state__( + {"state": state, "data_groups": data_groups, "defaults": defaults} + ) + + def __get_state__(self) -> dict[str, Any]: + data_groups = self._get_serializable_data_groups() + state = self._convert_mask(self.state) + return { + "defaults": self.defaults, + "state": state, + "data_groups": data_groups, + } + + def __set_state__(self, state: dict[str, Any]) -> None: + state["state"] = self._convert_mask( + state["state"], sparse_coo=False + ) # convert mask to dense tensor + self.__dict__.update(state) + + # need to attach layer and hook info into the data_groups + for name, config in self.data_groups.items(): + # fetch layer + layer = fqn_to_module(self.model, name) + assert layer is not None # satisfy mypy + + # if agg_mode is True, then layer in aggregate mode + if "hook_state" in config and config["hook_state"] == "aggregate": + hook = layer.register_forward_pre_hook(self._aggregate_hook(name)) + + elif "hook_state" in config and config["hook_state"] == "sparsify": + hook = layer.register_forward_pre_hook(self._sparsify_hook(name)) + + config["layer"] = layer + config["hook"] = hook # type: ignore[possibly-undefined] + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + for name, config in self.data_groups.items(): + format_string += "\n" + format_string += "\tData Group\n" + format_string += f"\t name: {name}\n" + for key in sorted(config.keys()): + if key in ["data", "hook", "reduce_fn", "mask_fn", "aggregate_fn"]: + continue + format_string += f"\t {key}: {config[key]}\n" + format_string += ")" + return format_string diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7564fe408b36e5fb62eb4cb2272ef432095981 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py @@ -0,0 +1,6 @@ +from .base_data_scheduler import BaseDataScheduler + + +__all__ = [ + "BaseDataScheduler", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..672903e8f058cbb7299e90b7728bf6e36c52e7b4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py @@ -0,0 +1,197 @@ +# mypy: allow-untyped-defs +import abc +import warnings +import weakref +from functools import wraps + +from torch.ao.pruning._experimental.data_sparsifier import BaseDataSparsifier + + +__all__ = ["BaseDataScheduler"] + + +class BaseDataScheduler: + r""" + The BaseDataScheduler is the abstract scheduler class specifically for the + BaseDataSparsifier class. This class controls a specific hyperparameter of + the sparsifier class and varies it across the training process (or across time). + + Args: + data_sparsifier (instance of BaseDataSparsifier) + Implemented class data sparsifier class wherein the update_mask is implemented + schedule_param (str) + A specific hyperparameter of the passed sparsifier that needs to be scheduled/varied + last_epoch (int, default=-1) + This is specifically is passed when training needs to be resumed from a particular + point. + verbose (bool, default=False) + Verbosity of the BaseDataScheduler + + The *get_hyperparam()* function needs to be implemented by the user. + """ + + def __init__( + self, data_sparsifier, schedule_param: str, last_epoch=-1, verbose=False + ): + # Attach sparsifier + if not isinstance(data_sparsifier, BaseDataSparsifier): + raise TypeError( + f"{type(data_sparsifier).__name__} is not an instance of torch.ao.pruning.BaseDataSparsifier" + ) + self.data_sparsifier = data_sparsifier + self.schedule_param = schedule_param + + # Initialize epoch and base hyper-params + self.base_param = { + name: config.get(schedule_param, None) + for name, config in self.data_sparsifier.data_groups.items() + } + + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `scheduler.step()` is called after + # `sparsifier.step()` + def with_counter(method): + if getattr(method, "_with_counter", False): + # `sparsifier.step()` has already been replaced, return. + return method + + # Keep a weak reference to the sparsifier instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 # type: ignore[union-attr] + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True # type: ignore[attr-defined] + return wrapper + + self.data_sparsifier.step = with_counter(self.data_sparsifier.step) # type: ignore[assignment] + self.data_sparsifier._step_count = 0 # type: ignore[attr-defined] + self._step_count: int = 0 + self.verbose = verbose + + # Housekeeping + self._get_sp_called_within_step: bool = False # sp -> schedule parameter + self.step() + + @abc.abstractmethod + def get_schedule_param(self): + r""" + Abstract method that needs to be implemented by the child class. + The expected return type should is a dictionary of name to schedule_param value + The returned values will be updated in sparsifier when the scheduler step() function + is called. + + Example: + >>> def get_schedule_param(self): + ... new_param = {} + ... for name in self.sparsifier.data_groups.keys(): + ... new_param[name] = ( + ... self.sparsifier.data_groups[name][self.schedule_param] * 0.5 + ... ) + ... return new_param + + When the step() function is called, the value in self.sparsifier.data_groups[name][self.schedule_param] + would be halved + """ + raise NotImplementedError + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + format_string += "\n" + format_string += f"Data Sparsifier {self.data_sparsifier}\n" + format_string += f" {self.schedule_param}: {self.base_param}\n" + format_string += ")" + return format_string + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the sparsifier. + + Note: + The scheduler class does not track the state of the data_sparsifier. + Make sure to store the state of the sparsifier before storing the + state of the scheduler + """ + return { + key: value + for key, value in self.__dict__.items() + if key != "data_sparsifier" + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Note: + Remember to restore the state of the data_sparsifier before the scheduler. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_param(self): + return self._last_param + + def step(self): + # Raise warning if trying to call scheduler step before the sparsifier. + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.data_sparsifier.step, "_with_counter"): + warnings.warn( + "Seems like `data_sparsifier.step()` has been overridden after sparsity scheduler " + "initialization. Please, make sure to call `data_sparsifier.step()` before " + "`scheduler.step()`.", + UserWarning, + ) + + # Just check if there were two first scheduler.step() calls before sparsifier.step() + elif self.data_sparsifier._step_count < 1: # type: ignore[attr-defined] + warnings.warn( + "Detected call of `scheduler.step()` before `data_sparsifier.step()`. " + "You have to make sure you run the data_sparsifier.step() BEFORE any " + "calls to the scheduler.step().", + UserWarning, + ) + self._step_count += 1 + + class _enable_get_sp_call: + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_sp_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_sp_called_within_step = False + + with _enable_get_sp_call(self): + self.last_epoch += 1 + updated_scheduler_params = self.get_schedule_param() + + for name, param in updated_scheduler_params.items(): + self.data_sparsifier.data_groups[name][self.schedule_param] = param + if self.verbose: + print(f"Adjusting {self.schedule_param} for group {name} to {param}") + + self._last_param = { + name: config.get(self.schedule_param, None) + for name, config in self.data_sparsifier.data_groups.items() + } + self.data_sparsifier.enable_mask_update = True diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1b5b9b96ec96fffdb0b66e21686a927a0c41b4a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__init__.py @@ -0,0 +1,8 @@ +from .base_data_sparsifier import BaseDataSparsifier +from .data_norm_sparsifier import DataNormSparsifier + + +__all__ = [ + "BaseDataSparsifier", + "DataNormSparsifier", +] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..3dea01586a2b3cc7ed711e54b580752761008368 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py @@ -0,0 +1,331 @@ +# mypy: allow-untyped-defs +import abc +import copy +import sys +import warnings +from collections import defaultdict +from typing import Any, Optional + +import torch +from torch import nn +from torch.ao.pruning.sparsifier import base_sparsifier, utils +from torch.nn.utils import parametrize + + +if not sys.warnoptions: + # to suppress repeated warnings when being used in a training loop. + warnings.simplefilter("once") + +__all__ = ["BaseDataSparsifier"] + +EMBEDDING_TYPES = { + nn.Embedding, + nn.EmbeddingBag, +} + +SUPPORTED_TYPES = { + torch.Tensor, + nn.Parameter, + *EMBEDDING_TYPES, +} + + +class _Container(nn.Module): + pass + + +class BaseDataSparsifier(base_sparsifier.BaseSparsifier): + r""" + Base Data Sparsifier class for all Data sparsifiers. + The abstract class accepts raw torch tensors / embedding / embedding bags (refer to SUPPORTED_TYPES above) + to prepare for sparsification. + In this case, mask (and parametrizations) is owned by the class and not by the user. + Specifically, the container object inside the class maintains the mask and parametrizations of the input data + + Args: + data_list (list of tuples) + list of (name, data) tuples to sparsify. Lookup SUPPORTED_TYPES + for type of data. Internally, a container module handles the data sparsification. + + defaults (dict) + default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + Example:: + >>> # xdoctest: +SKIP + >>> data_list = [('tensor_1', torch.randn(3,3)), ('tensor_2', torch.randn(4,4))] + >>> defaults = {'sparsity_level': 0.7} + >>> sparsifier = DerivedDataSparsifier(data_list = data_list, **defaults) # Some sparsifier that inherits BaseDataSparsifier + >>> new_tensor_to_add = {'name': 'tensor_3', 'data': torch.randn(5,5), 'sparsity_level': 0.3} + >>> sparsifier.add_data(**new_tensor_to_add) + >>> # tensor_1 and tensor_2 will have sparsity_level of 0.7 but tensor_3 will have sparsity_level=0.3 + """ + + def __init__(self, data_list: Optional[list[tuple[str, Any]]] = None, **defaults): + super().__init__(defaults=defaults) + + self._container = _Container() + + self.data_groups: dict[str, dict] = defaultdict(dict) # name -> {**config} + if data_list is not None: + # add data with default config here + [self.add_data(name, data, **self.defaults) for name, data in data_list] + + def prepare(self, model, config): + raise NotImplementedError("this function is undefined for this class") + + def _extract_weight(self, data): + # extract the weight parameter instead of underlying data + if type(data) in [torch.Tensor, nn.Parameter]: + return data + elif type(data) in EMBEDDING_TYPES: + return data.weight + + def add_data(self, name: str, data, reuse_mask=True, **config): + r"""Configures and parametrizes the internal container model with name and data. + + **Note**: + 1. If the data with name already exists, it replaces the data. + 2. While replacing, the old mask is reused when `reuse_mask=True` + 3. If `reuse_mask=True`, then the replacing data needs to have the same shape as that of old data. + 4. By default, the config of the replaced data is used as config for the replacing data, unless something + is specified in the config dictionary. + """ + assert type(data) in SUPPORTED_TYPES, ( + "specified data type not supported at the moment" + ) + local_args = copy.deepcopy(self.defaults) + local_args.update(config) + weight = self._extract_weight(data) + + # Bookkeeping in the container class + mask = local_args.get("mask", torch.ones_like(weight)) + param_class = local_args.get("parametrization", utils.FakeSparsity) + + if name in self.state: + # If the named data already exists - replace + warnings.warn( + "Replacing existing data of the same name. - Did you mean a different name?" + ) + + # reuse old config + old_args = self.data_groups[name] + local_args = copy.deepcopy(old_args) + local_args.update(config) + + if reuse_mask: + current_data = self.get_data(name=name) + assert weight.shape == current_data.shape, ( + "to retain the old mask, the shape of the new data must be the same as the previous one" + ) + mask = self.get_mask( + name=name + ) # reuse mask instead of creating a new one + + self._delete_data(name=name) + + # parameter creates a deepcopy of the weight inside, so create a buffer + self._container.register_buffer(name=name, tensor=weight) + parametrize.register_parametrization(self._container, name, param_class(mask)) + self.state[name]["mask"] = mask + self.data_groups[name] = local_args + return getattr(self._container, name) + + def get_data(self, name: str, return_original: bool = True): + r"""Returns weight tensor (or data) + Args: + - name: name of the data to be returned + - return_original returns weight tensor without applying parametrization if True + else - returns the sparsified version (parametrized) + """ + if name not in self.data_groups: + raise ValueError("data with specified name does not exist") + + if return_original: + if not parametrize.is_parametrized(self._container, name): + raise ValueError("mask squashed - original mask value does not exist") + data = getattr(self._container.parametrizations, name).original + return data + else: + return getattr(self._container, name) + + def _convert_mask(self, states, sparse_coo=True): + r"""Converts the mask to sparse coo or dense tensors depending on the `sparse_coo` argument.""" + states = copy.deepcopy(states) + for state in states.values(): + if sparse_coo: + state["mask"] = state["mask"].to_sparse_coo() + else: + state["mask"] = state["mask"].to_dense() + + return states + + def state_dict(self): + r"""Returns the state of the optimizer as a :class:`dict`. + + It contains: + * state - contains name -> mask mapping. + * data_groups - a list containing all sparsity configuration groups + with the key name specifying the name of the data + * container_state_dict - the state dictionary of the internal + container model used for sparsification + """ + state = self._convert_mask(self.state) + return { + "state": state, + "data_groups": self.data_groups, + "_container": self._container.state_dict(), + } + + def _load_container_from_state(self, states, data_groups, container_state_dict): + r"""This restores the state of the container specifically based on the data present in state and data_groups + If the data was parametrized, then the data would be added to the container and then parametrized, + else it would just add the attribute the container. + """ + for name, state in states.items(): + config_name = data_groups.get(name, None) + if config_name is None: + raise RuntimeError(f"Error loading {name}") + + # check if the data with such a name was parametrized, if so parametrize + # otherwise just set the attribute and continue + parametrized_name = f"parametrizations.{name}.original" + parametrized = False + data = container_state_dict.get(name, None) + if name in container_state_dict: + # the parametrization was probably removed for this + data = container_state_dict.get(name) + + elif parametrized_name in container_state_dict: + # so the weight was parametrized + data = container_state_dict.get(parametrized_name) + parametrized = True + + else: + raise RuntimeError(f"Error loading {name}") + + self._container.register_buffer(name=name, tensor=data) + + if parametrized: + # register parameter if parametrized + mask = state.get("mask", torch.ones_like(data)) + param_class = data_groups.get( + "parametrization", utils.FakeSparsity + ) # change once public_api for utils is fixed! + parametrize.register_parametrization( + self._container, name, param_class(mask) + ) + + def load_state_dict(self, state_dict, strict=True): + r"""The load_state_dict() restores the state of the sparsifier based on the state_dict + + Args: + * state_dict - the dictionary that to which the current sparsifier needs to be restored to + * strict - If True - the sparsifier is reset and is restored exactly to the state in state_dict. + If False - the current sparsifier is not reset before loading the state_dict i.e. data added + before loading the state_dict is not erased. + """ + states = copy.deepcopy(state_dict["state"]) + data_groups = copy.deepcopy(state_dict["data_groups"]) + container_state_dict = copy.deepcopy(state_dict["_container"]) + + states = self._convert_mask( + states, sparse_coo=False + ) # convert sparse coo mask to dense + if strict: + # if strict load -> then reset container + self._container = _Container() + + self._load_container_from_state(states, data_groups, container_state_dict) + + if not strict: + states.update(self.state) + data_groups.update(self.data_groups) + + self.__setstate__({"state": states, "data_groups": data_groups}) + + def __setstate__(self, state): + if "_container" in state: # If container object is in state then load model + container_dict = state.pop("_container") + self._container = _Container() + state["state"] = self._convert_mask( + state["state"], sparse_coo=False + ) # convert sparse coo mask to dense + self._load_container_from_state( + state["state"], state["data_groups"], container_dict + ) + + self.__dict__.update(state) + + def __getstate__(self): + state = self._convert_mask(self.state) + return { + "defaults": self.defaults, + "state": state, + "data_groups": self.data_groups, + "_container": self._container.state_dict(), + } + + def __repr__(self): # type:ignore[override] + format_string = self.__class__.__name__ + " (" + for name, sparse_args in self.data_groups.items(): + format_string += "\n" + format_string += "\tData Group\n" + format_string += f"\t name: {name}\n" + for key in sorted(sparse_args.keys()): + if key == "data": + continue + format_string += f"\t {key}: {sparse_args[key]}\n" + format_string += ")" + return format_string + + def get_mask(self, name: str): + if name not in self.state: + raise ValueError("data with specified name does not exist") + return self.state[name]["mask"] + + def squash_mask(self, *args, leave_parametrized=True, names=None, **kwargs): + r"""Squashes the sparse masks into the appropriate tensors. Also, accepts list of strings + to squash mask for. If none, squashes mask for all the keys + kwargs: + * names: list of strings to squash mask for + * sparsified: if true - applies the mask before squashing + if false - does not apply the mask before squashing + """ + if names is None: + names = list(self.data_groups.keys()) + for name in names: + parametrize.remove_parametrizations( + self._container, name, leave_parametrized=leave_parametrized + ) + + def step(self): # type:ignore[override] + if not self.enable_mask_update: + return + with torch.no_grad(): + for name, config in self.data_groups.items(): + # get non-sparsified data + data = self.get_data(name) + # need name for the mask otherwise can directly pass mask? + self.update_mask(name, data, **config) + + @abc.abstractmethod + def update_mask(self, name, data, **kwargs): # type: ignore[override] + pass + + def _delete_data(self, name): + """Detaches some data from the sparsifier. + + Args: + name (str) + Name of the data to be removed from the sparsifier + + Note: + Currently private. Kind of used as a helper function when replacing data of the same name + """ + self.squash_mask( + names=[name], leave_parametrized=False + ) # do not apply the mask while deleting + delattr(self._container, name) + self.state.pop(name) + self.data_groups.pop(name) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4b4f913f5033081fc34c4f6b6057da25b93485 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -0,0 +1,203 @@ +# mypy: allow-untyped-defs +import operator +from functools import reduce +from typing import Any, Optional + +import torch +from torch.nn import functional as F + +from .base_data_sparsifier import BaseDataSparsifier + + +__all__ = ["DataNormSparsifier"] + + +class DataNormSparsifier(BaseDataSparsifier): + r"""L1-Norm Sparsifier + This sparsifier computes the *L1-norm* of every sparse block and "zeroes-out" the + ones with the lowest norm. The level of sparsity defines how many of the + blocks is removed. + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out + 2. `sparse_block_shape` defines the shape of the sparse blocks. Note that + the sparse blocks originate at the zero-index of the tensor. + 3. `zeros_per_block` is the number of zeros that we are expecting in each + sparse block. By default we assume that all elements within a block are + zeroed-out. However, setting this variable sets the target number of + zeros per block. The zeros within each block are chosen as the *smallest + absolute values*. + Args: + sparsity_level: The target level of sparsity + sparse_block_shape: The shape of a sparse block + zeros_per_block: Number of zeros in a sparse block + Note:: + All arguments to the DataNormSparsifier constructor are "default" + arguments and could be overridden by the configuration provided in the + `add_data` step. + """ + + def __init__( + self, + data_list: Optional[list[tuple[str, Any]]] = None, + sparsity_level: float = 0.5, + sparse_block_shape: tuple[int, int] = (1, 4), + zeros_per_block: Optional[int] = None, + norm: str = "L1", + ): + if zeros_per_block is None: + zeros_per_block = reduce(operator.mul, sparse_block_shape) + + assert norm in ["L1", "L2"], "only L1 and L2 norm supported at the moment" + + defaults = { + "sparsity_level": sparsity_level, + "sparse_block_shape": sparse_block_shape, + "zeros_per_block": zeros_per_block, + } + self.norm = norm + super().__init__(data_list=data_list, **defaults) + + def __get_scatter_folded_mask( + self, data, dim, indices, output_size, sparse_block_shape + ): + mask = torch.ones_like(data) + mask.scatter_(dim=dim, index=indices, value=0) # zeroing out + mask = F.fold( + mask, + output_size=output_size, + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ) + mask = mask.to(torch.int8) + return mask + + def __get_block_level_mask(self, data, sparse_block_shape, zeros_per_block): + # Assume data is a squeezed tensor + height, width = data.shape[-2], data.shape[-1] + block_height, block_width = sparse_block_shape + values_per_block = block_height * block_width + + # just return zeros if zeroing all elements in block + if values_per_block == zeros_per_block: + return torch.zeros_like(data, dtype=torch.int8) + + # creating additional height and width to support padding + dh = (block_height - height % block_height) % block_height + dw = (block_width - width % block_width) % block_width + + # create a new padded tensor like data (to match the block_shape) + padded_data = torch.ones( + height + dh, width + dw, dtype=data.dtype, device=data.device + ) + padded_data = ( + padded_data * torch.nan + ) # can also be replaced with 0 to stop the removal of edge data + padded_data[0:height, 0:width] = data + unfolded_data = F.unfold( + padded_data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ) + + _, sorted_idx = torch.sort(unfolded_data, dim=1) + sorted_idx = sorted_idx[ + :, :zeros_per_block, : + ] # zero out zeros_per_block number of elements + + mask = self.__get_scatter_folded_mask( + data=unfolded_data, + dim=1, + indices=sorted_idx, + output_size=padded_data.shape, + sparse_block_shape=sparse_block_shape, + ) + + mask = ( + mask.squeeze(0).squeeze(0)[:height, :width].contiguous() + ) # remove padding and make contiguous + return mask + + def __get_data_level_mask(self, data, sparsity_level, sparse_block_shape): + height, width = data.shape[-2], data.shape[-1] + block_height, block_width = sparse_block_shape + dh = (block_height - height % block_height) % block_height + dw = (block_width - width % block_width) % block_width + + data_norm = F.avg_pool2d( + data[None, None, :], + kernel_size=sparse_block_shape, + stride=sparse_block_shape, + ceil_mode=True, + ) + + values_per_block = reduce(operator.mul, sparse_block_shape) + + data_norm = data_norm.flatten() + num_blocks = len(data_norm) + + data_norm = data_norm.repeat( + 1, values_per_block, 1 + ) # get similar shape after unfold + _, sorted_idx = torch.sort(data_norm, dim=2) + + threshold_idx = round(sparsity_level * num_blocks) # number of blocks to remove + sorted_idx = sorted_idx[:, :, :threshold_idx] + + mask = self.__get_scatter_folded_mask( + data=data_norm, + dim=2, + indices=sorted_idx, + output_size=(height + dh, width + dw), + sparse_block_shape=sparse_block_shape, + ) + + mask = mask.squeeze(0).squeeze(0)[ + :height, :width + ] # squeeze only the first 2 dimension + return mask + + def update_mask( # type: ignore[override] + self, name, data, sparsity_level, sparse_block_shape, zeros_per_block, **kwargs + ): + values_per_block = reduce(operator.mul, sparse_block_shape) + if zeros_per_block > values_per_block: + raise ValueError( + "Number of zeros per block cannot be more than " + "the total number of elements in that block." + ) + if zeros_per_block < 0: + raise ValueError("Number of zeros per block should be positive.") + + if self.norm == "L1": + data_norm = torch.abs(data).squeeze() # absolute value based (L1) + else: + data_norm = (data * data).squeeze() # square every element for L2 + + if len(data_norm.shape) > 2: # only supports 2 dimensional data at the moment + raise ValueError("only supports 2-D at the moment") + + elif len(data_norm.shape) == 1: # in case the data is bias (or 1D) + data_norm = data_norm[None, :] + + mask = self.get_mask(name) + if sparsity_level <= 0 or zeros_per_block == 0: + mask.data = torch.ones_like(mask) + elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block): + mask.data = torch.zeros_like(mask) + + # Fetch the high level mask that zeros out entire blocks + data_lvl_mask = self.__get_data_level_mask( + data=data_norm, + sparsity_level=sparsity_level, + sparse_block_shape=sparse_block_shape, + ) + + # Fetch block level mask that zeros out 'zeros_per_block' number of elements in every block + block_lvl_mask = self.__get_block_level_mask( + data=data_norm, + sparse_block_shape=sparse_block_shape, + zeros_per_block=zeros_per_block, + ) + + # zero out the entries inside those blocks whose block is sparsified + mask.data = torch.where(data_lvl_mask == 1, data_lvl_mask, block_lvl_mask) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..50d5684961bc807d5ae1b02615ade168416c9b3d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py @@ -0,0 +1,44 @@ +# mypy: allow-untyped-defs +import logging + +from torch.ao.pruning._experimental.data_sparsifier.base_data_sparsifier import ( + SUPPORTED_TYPES, +) + + +logger: logging.Logger = logging.getLogger(__name__) + + +def _attach_model_to_data_sparsifier(module, data_sparsifier, config=None): + """Attaches a data sparsifier to all the layers of the module. + Essentially, loop over all the weight parameters in the module and + attach it to the data sparsifier. + Note:: + The '.' in the layer names are replaced with '_' (refer to _get_valid_name() below) + before attaching to the sparsifier. This is because, the data + sparsifier uses a dummy model inside to store the weight parameters. + """ + if config is None: + config = {} + for name, parameter in module.named_parameters(): + if type(parameter) in SUPPORTED_TYPES: + valid_name = _get_valid_name(name) + # will be defaulted to default configs + data_sparsifier.add_data( + name=valid_name, data=parameter, **config.get(valid_name, {}) + ) + + +def _get_valid_name(name): + return name.replace(".", "_") # . is not allowed as a name + + +def _log_sparsified_level(model, data_sparsifier) -> None: + # Show the level of sparsity AFTER step: + for name, parameter in model.named_parameters(): + if type(parameter) not in SUPPORTED_TYPES: + continue + valid_name = _get_valid_name(name) + mask = data_sparsifier.get_mask(name=valid_name) + sparsity_level = 1.0 - mask.float().mean() + logger.info("Sparsity in layer %s = % .2%", name, sparsity_level) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py new file mode 100644 index 0000000000000000000000000000000000000000..00e9b1cab6c3ceb55d8a053e6db06014fa4f30c5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +from collections import defaultdict +from copy import deepcopy +from typing import Any, Optional, TYPE_CHECKING + +import pytorch_lightning as pl # type: ignore[import] + +from ._data_sparstity_utils import ( + _attach_model_to_data_sparsifier, + _get_valid_name, + _log_sparsified_level, +) + + +if TYPE_CHECKING: + import torch + + +class PostTrainingDataSparsity(pl.callbacks.Callback): + """Lightning callback that enables post-training sparsity. + + This callback aims to sparsify the model inside lightning module after training. + **Note that the model is copied and then sparsified, so the existing model is not modified** + + The sparsified model can be used for comparison and can be accessed using + .sparsified + + Args: + data_sparsifier_class (some implemented class of BaseDataSparsifier) + The data sparsifier object of this class is created when the + training starts. + Note: Objects should not be passed in here as they are created + once the training completes. + + data_sparsifier_args (Dict) + Dictionary of args to be passed to the data sparsifier. + Note: data_list arg should be ignored + + Hooks implemented: + on_fit_end() + 1. copies the model and attaches it to the sparsifier + 2. sparsier step() is called + 3. squashes the mask() + """ + + def __init__(self, data_sparsifier_class, data_sparsifier_args): + super().__init__() + self.data_sparsifier_class = data_sparsifier_class + self.data_sparsifier_args = data_sparsifier_args + self.data_sparsifier: Any = None + self.sparsified: Optional[torch.nn.Module] = None + + def on_fit_end(self, trainer, pl_module) -> None: + self.sparsified = deepcopy(pl_module.model).eval() + self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args) + + _attach_model_to_data_sparsifier(self.sparsified, self.data_sparsifier) + + self.data_sparsifier.step() + + self.data_sparsifier.squash_mask() # currently squashes params for all mask + + _log_sparsified_level(self.sparsified, self.data_sparsifier) + + +class TrainingAwareDataSparsity(pl.callbacks.Callback): + """Lightning callback that enables in-training sparsity. + + This callback aims to sparsify the model inside lightning module during training. + **Note that the model is copied and then sparsified, so the existing model is not modified** + + The sparsified model can be used for comparison and can be accessed using + .sparsified + + Args: + data_sparsifier_class (some implemented class of BaseDataSparsifier) + The data sparsifier object of this class is created when the + training starts. + Note: Objects should not be passed in here as they are created + when the training starts. + + data_sparsifier_args (Dict) + Dictionary of args to be passed to the data sparsifier. + Note: data_list arg should be ignored + + data_scheduler_class (some implemented class of BaseDataScheduler) + The data scheduler of this class is created when the training starts + Note: Objects should not be passed in here as they are created + when the training starts. + + data_scheduler_args(Dict) + Dictionary of args to be passed to the data scheduler. + **Note: data_sparsifier arg should be ignored as the recipe + creates and pass sparsifier object into the class** + + Hooks implemented: + on_train_start() + Data sparsifier and scheduler objects are created. + Pytorch model attached to the sparsifier + + on_train_epoch_start() + Loads the state_dict of the data sparsifier + + on_train_epoch_end() + 1. Copies the model and attaches it to the sparsifier + 2. sparsifier step() and scheduler step() + 3. Dump state_dict of the current sparsifier + + on_train_end() + squash mask + """ + + def __init__( + self, + data_sparsifier_class, + data_sparsifier_args, + data_scheduler_class, + data_scheduler_args, + ): + super().__init__() + # data sparsifier objects + self.data_sparsifier_class = data_sparsifier_class + self.data_sparsifier_args = data_sparsifier_args + + # scheduler objects + self.data_scheduler_class = data_scheduler_class + self.data_scheduler_args = data_scheduler_args + + # fields + self.data_sparsifier: Any = None + self.data_scheduler: Any = None + self.sparsified: Optional[torch.nn.Module] = None + + self.data_sparsifier_state_dict: Any = None + + def on_train_start(self, trainer, pl_module) -> None: + # create sparsifier + self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args) + self.sparsified = deepcopy(pl_module.model) + + _attach_model_to_data_sparsifier( + self.sparsified, self.data_sparsifier + ) # just to populate the base_sl in the scheduler + + # create scheduler + args = deepcopy(self.data_scheduler_args) + args["data_sparsifier"] = self.data_sparsifier + self.data_scheduler = self.data_scheduler_class(**args) + + def on_train_epoch_start(self, trainer, pl_module): + if self.data_sparsifier_state_dict is None: + return # probably first epoch + + # load the existing config for each data + self.data_sparsifier.load_state_dict(self.data_sparsifier_state_dict) + + def __create_config_based_on_state(self, pl_module): + config: dict = defaultdict() + if self.data_sparsifier_state_dict is None: + return config + for name, _ in pl_module.model.named_parameters(): + valid_name = _get_valid_name(name) + config[valid_name] = self.data_sparsifier.data_groups[valid_name] + + return config + + def on_train_epoch_end(self, trainer, pl_module): + self.sparsified = deepcopy(pl_module.model) + config = self.__create_config_based_on_state(pl_module) + + # attach model to the data sparsifier + _attach_model_to_data_sparsifier( + self.sparsified, self.data_sparsifier, config=config + ) + self.data_sparsifier.step() + self.data_scheduler.step() + + self.data_sparsifier_state_dict = self.data_sparsifier.state_dict() + + def on_train_end(self, trainer, pl_module): + self.data_sparsifier.squash_mask() diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b2943e2af1a872edc56e95452f2b0610f1fb0007 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py @@ -0,0 +1,150 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.nn as nn +from torch.ao.pruning.sparsifier.utils import fqn_to_module, module_to_fqn + + +SUPPORTED_MODULES = {nn.Embedding, nn.EmbeddingBag} + + +def _fetch_all_embeddings(model): + """Fetches Embedding and EmbeddingBag modules from the model""" + embedding_modules = [] + stack = [model] + while stack: + module = stack.pop() + for _, child in module.named_children(): + fqn_name = module_to_fqn(model, child) + if type(child) in SUPPORTED_MODULES: + embedding_modules.append((fqn_name, child)) + else: + stack.append(child) + return embedding_modules + + +def post_training_sparse_quantize( + model, + data_sparsifier_class, + sparsify_first=True, + select_embeddings: Optional[list[nn.Module]] = None, + **sparse_config, +): + """Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags. + The quantization step can happen before or after sparsification depending on the `sparsify_first` argument. + + Args: + - model (nn.Module) + model whose embeddings needs to be sparsified + - data_sparsifier_class (type of data sparsifier) + Type of sparsification that needs to be applied to model + - sparsify_first (bool) + if true, sparsifies first and then quantizes + otherwise, quantizes first and then sparsifies. + - select_embeddings (List of Embedding modules) + List of embedding modules to in the model to be sparsified & quantized. + If None, all embedding modules with be sparsified + - sparse_config (Dict) + config that will be passed to the constructor of data sparsifier object. + + Note: + 1. When `sparsify_first=False`, quantization occurs first followed by sparsification. + - before sparsifying, the embedding layers are dequantized. + - scales and zero-points are saved + - embedding layers are sparsified and `squash_mask` is applied + - embedding weights are requantized using the saved scales and zero-points + 2. When `sparsify_first=True`, sparsification occurs first followed by quantization. + - embeddings are sparsified first + - quantization is applied on the sparsified embeddings + """ + data_sparsifier = data_sparsifier_class(**sparse_config) + + # if select_embeddings is None, perform it on all embeddings + if select_embeddings is None: + embedding_modules = _fetch_all_embeddings(model) + + else: + embedding_modules = [] + assert isinstance(select_embeddings, list), ( + "the embedding_modules must be a list of embedding modules" + ) + for emb in select_embeddings: + assert type(emb) in SUPPORTED_MODULES, ( + "the embedding_modules list must be an embedding or embedding bags" + ) + fqn_name = module_to_fqn(model, emb) + assert fqn_name is not None, ( + "the embedding modules must be part of input model" + ) + embedding_modules.append((fqn_name, emb)) + + if sparsify_first: + # sparsify + for name, emb_module in embedding_modules: + valid_name = name.replace(".", "_") + data_sparsifier.add_data(name=valid_name, data=emb_module) + + data_sparsifier.step() + data_sparsifier.squash_mask() + + # quantize + for _, emb_module in embedding_modules: + emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig + + torch.ao.quantization.prepare(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) + + else: + # quantize + for _, emb_module in embedding_modules: + emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig + + torch.ao.quantization.prepare(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) + + # retrieve scale & zero_points + quantize_params: dict[str, dict] = { + "scales": {}, + "zero_points": {}, + "dequant_weights": {}, + "axis": {}, + "dtype": {}, + } + + for name, _ in embedding_modules: + quantized_emb = fqn_to_module(model, name) + assert quantized_emb is not None # satisfy mypy + + quantized_weight = quantized_emb.weight() # type: ignore[operator] + quantize_params["scales"][name] = quantized_weight.q_per_channel_scales() + quantize_params["zero_points"][name] = ( + quantized_weight.q_per_channel_zero_points() + ) + quantize_params["dequant_weights"][name] = torch.dequantize( + quantized_weight + ) + quantize_params["axis"][name] = quantized_weight.q_per_channel_axis() + quantize_params["dtype"][name] = quantized_weight.dtype + + # attach data to sparsifier + data_sparsifier.add_data( + name=name.replace(".", "_"), + data=quantize_params["dequant_weights"][name], + ) + + data_sparsifier.step() + data_sparsifier.squash_mask() + + for name, _ in embedding_modules: + quantized_emb = fqn_to_module(model, name) + assert quantized_emb is not None # satisfy mypy + requantized_vector = torch.quantize_per_channel( + quantize_params["dequant_weights"][name], + scales=quantize_params["scales"][name], + zero_points=quantize_params["zero_points"][name], + dtype=quantize_params["dtype"][name], + axis=quantize_params["axis"][name], + ) + + quantized_emb.set_weight(requantized_vector) # type: ignore[operator] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..3da27ba38df55b6ec738ae682aac5cbfc4da731d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -0,0 +1,97 @@ +# mypy: allow-untyped-defs +from typing import Callable, Optional, Union + +import torch + +from .base_structured_sparsifier import BaseStructuredSparsifier + + +__all__ = ["FPGMPruner"] + + +class FPGMPruner(BaseStructuredSparsifier): + r"""Filter Pruning via Geometric Median (FPGM) Structured Pruner + This sparsifier prune fliter (row) in a tensor according to distances among filters according to + `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration `_. + + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of filters (rows) that are zeroed-out. + 2. `dist` defines the distance measurement type. Default: 3 (L2 distance). + Available options are: [1, 2, (custom callable distance function)]. + + Note:: + Inputs should be a 4D convolutional tensor of shape (N, C, H, W). + - N: output channels size + - C: input channels size + - H: height of kernel + - W: width of kernel + """ + + def __init__( + self, sparsity_level: float = 0.5, dist: Optional[Union[Callable, int]] = None + ): + defaults = { + "sparsity_level": sparsity_level, + } + + if dist is None: + dist = 2 + + if callable(dist): + self.dist_fn = dist + elif dist == 1: + self.dist_fn = lambda x: torch.cdist(x, x, p=1) + elif dist == 2: + self.dist_fn = lambda x: torch.cdist(x, x, p=2) + else: + raise NotImplementedError("Distance function is not yet implemented.") + super().__init__(defaults=defaults) + + def _compute_distance(self, t): + r"""Compute distance across all entries in tensor `t` along all dimension + except for the one identified by dim. + Args: + t (torch.Tensor): tensor representing the parameter to prune + Returns: + distance (torch.Tensor): distance computed across filtters + """ + dim = 0 # prune filter (row) + + size = t.size(dim) + slc = [slice(None)] * t.dim() + + # flatten the tensor along the dimension + t_flatten = [ + t[tuple(slc[:dim] + [slice(i, i + 1)] + slc[dim + 1 :])].reshape(-1) + for i in range(size) + ] + t_flatten = torch.stack(t_flatten) + + # distance measurement + dist_matrix = self.dist_fn(t_flatten) + + # more similar with other filter indicates large in the sum of row + distance = torch.sum(torch.abs(dist_matrix), 1) + + return distance + + def update_mask( # type: ignore[override] + self, module, tensor_name, sparsity_level, **kwargs + ): + tensor_weight = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + if sparsity_level <= 0: + mask.data = torch.ones_like(mask).bool() + elif sparsity_level >= 1.0: + mask.data = torch.zeros_like(mask).bool() + else: + distance = self._compute_distance(tensor_weight) + + tensor_size = tensor_weight.shape[0] # prune filter (row) + nparams_toprune = round(sparsity_level * tensor_size) + nparams_toprune = min( + max(nparams_toprune, 0), tensor_size + ) # clamp to [0, tensor_size] + topk = torch.topk(distance, k=nparams_toprune, largest=False) + mask[topk.indices] = False diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a57db6a8d8cde9a89c7cbda4dff6f6075559b59b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/__init__.py @@ -0,0 +1,5 @@ +from .base_structured_sparsifier import BaseStructuredSparsifier +from .FPGM_pruner import FPGMPruner +from .lstm_saliency_pruner import LSTMSaliencyPruner +from .parametrization import BiasHook, FakeStructuredSparsity +from .saliency_pruner import SaliencyPruner diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..ffbb99bb2967e10a221578718e146c55131629c2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py @@ -0,0 +1,312 @@ +# mypy: allow-untyped-defs +from itertools import chain +from operator import getitem +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier +from torch.fx import symbolic_trace +from torch.nn.utils import parametrize + +from .match_utils import apply_match, MatchAllNode +from .parametrization import BiasHook, FakeStructuredSparsity, module_contains_param +from .prune_functions import ( + prune_conv2d, + prune_conv2d_activation_conv2d, + prune_conv2d_activation_pool_conv2d, + prune_conv2d_conv2d, + prune_conv2d_pool_activation_conv2d, + prune_conv2d_pool_flatten_linear, + prune_linear, + prune_linear_activation_linear, + prune_linear_linear, + prune_lstm_output_layernorm_linear, + prune_lstm_output_linear, +) + + +def _get_supported_structured_pruning_modules(): + SUPPORTED_STRUCTURED_PRUNING_MODULES = { # added to config if None given + nn.Linear, + nn.Conv2d, + nn.LSTM, + } + return SUPPORTED_STRUCTURED_PRUNING_MODULES + + +def _get_supported_activation_functions(): + SUPPORTED_ACTIVATION_FUNCTIONS = { + F.relu, + F.rrelu, + F.hardtanh, + F.relu6, + F.sigmoid, + F.hardsigmoid, + F.tanh, + F.silu, + F.mish, + F.hardswish, + F.elu, + F.celu, + F.selu, + F.hardshrink, + F.leaky_relu, + F.logsigmoid, + F.softplus, + F.prelu, + F.softsign, + F.tanhshrink, + F.gelu, + } + return SUPPORTED_ACTIVATION_FUNCTIONS + + +def _get_supported_activation_modules(): + SUPPORTED_ACTIVATION_MODULES = { + nn.ReLU, + nn.RReLU, + nn.Hardtanh, + nn.ReLU6, + nn.Sigmoid, + nn.Hardsigmoid, + nn.Tanh, + nn.SiLU, + nn.Mish, + nn.Hardswish, + nn.ELU, + nn.CELU, + nn.SELU, + nn.Hardshrink, + nn.LeakyReLU, + nn.LogSigmoid, + nn.Softplus, + nn.PReLU, + nn.Softsign, + nn.Tanhshrink, + nn.GELU, + } + return SUPPORTED_ACTIVATION_MODULES + + +def _get_default_structured_pruning_patterns() -> dict[ + tuple[Union[type[nn.Module], Callable, MatchAllNode, str], ...], + Callable[..., None], +]: + """ + Returns the patterns for conv2d / linear conversion for each element in the activation functions/modules defined above. + """ + patterns: dict[ + tuple[Union[type[nn.Module], Callable, MatchAllNode, str], ...], + Callable[..., None], + ] = { + # linear -> linear + (nn.Linear, "output"): prune_linear, + (nn.Linear, nn.Linear): prune_linear_linear, + # conv2d -> conv2d + (nn.Conv2d, "output"): prune_conv2d, + (nn.Conv2d, nn.Conv2d): prune_conv2d_conv2d, + # TODO LSTM Structured pruning does not support returned state currently. + # Should find a way to explicitly match getitem(0) instead of getitem. + # This will also require changing the pruning function. + # lstm -> getitem(0) -> linear + (nn.LSTM, getitem, nn.Linear): prune_lstm_output_linear, + # lstm -> getitem(0) -> layernorm -> linear + (nn.LSTM, getitem, nn.LayerNorm, nn.Linear): prune_lstm_output_layernorm_linear, + } + + for activation in chain( + _get_supported_activation_functions(), _get_supported_activation_modules() + ): + patterns.update( + { + # linear -> activation -> linear + (nn.Linear, activation, nn.Linear): prune_linear_activation_linear, + # conv2d -> activation -> conv2d + (nn.Conv2d, activation, nn.Conv2d): prune_conv2d_activation_conv2d, + # conv2d -> activation -> pool -> conv2d + ( + nn.Conv2d, + activation, + nn.AvgPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.avg_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + nn.MaxPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.max_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + # conv2d -> pool -> activation -> conv2d + ( + nn.Conv2d, + nn.AvgPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.avg_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + nn.MaxPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.max_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + # conv2d -> adaptive pool -> flatten -> linear + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + } + ) + return patterns + + +class BaseStructuredSparsifier(BaseSparsifier): + r"""Base class for structured pruning. + + Abstract methods that need to be implemented: + - update_mask: Function to compute a new mask for all keys in the + `groups` attribute. + + Args: + - defaults [dict]: default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + """ + + def __init__(self, defaults, patterns=None): + super().__init__(defaults) + if patterns is None: + patterns = _get_default_structured_pruning_patterns() + self.patterns = patterns + + def make_config_from_model( + self, + model: nn.Module, + SUPPORTED_MODULES: Optional[set[type]] = None, + ) -> None: + if SUPPORTED_MODULES is None: + SUPPORTED_MODULES = _get_supported_structured_pruning_modules() + super().make_config_from_model(model, SUPPORTED_MODULES=SUPPORTED_MODULES) + + def _prepare(self, *args, **kwargs) -> None: + r"""This function will attach the FakeStructuredSparsity parameterizations + and BiasHooks at the appropriate points in the model. + """ + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrization = config.get("parametrization", FakeStructuredSparsity) + tensor = getattr(module, tensor_name) + + mask = config.get( + "mask", + torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device), + ) + self.state[config["tensor_fqn"]]["mask"] = mask + parametrize.register_parametrization( + module, tensor_name, parametrization(mask) + ) + + # if linear / conv, we add in bias hooks + if isinstance(module, (nn.Linear, nn.Conv2d)): + prune_bias = config.get("prune_bias", True) + if module.bias is not None: + module.register_parameter( + "_bias", nn.Parameter(module.bias.detach()) + ) + module.bias = None + module.prune_bias = prune_bias + + module.register_forward_hook( + BiasHook(module.parametrizations.weight[0], prune_bias) # type: ignore[union-attr, index] + ) + + def prune(self) -> None: + r""" + This function will FX symbolically trace the model and then find instances of the patterns + defined in self.patterns (by default SUPPORTED_STRUCTURED_PRUNING_PATTERNS ). + + For each pattern, it will apply to corresponding conversion function, which will modify the output + and input size expected by the modules within the pattern + """ + + self.traced = symbolic_trace(self.model) + modules = dict(self.traced.named_modules()) + + # Right now we check for matches simply by iterating across all the patterns + # if this is slow we can store patterns in a trie-structure and modify this code for faster lookup + for node in self.traced.graph.nodes: + for pattern, convert_fn in self.patterns.items(): + matched = apply_match(modules, pattern, node, []) + if matched is None: + continue + + first_module = modules.get(node.target) + # check if first module exists and has appropriate parameterization, otherwise skip + if ( + first_module is not None + and parametrize.is_parametrized(first_module) + and module_contains_param(first_module, FakeStructuredSparsity) + ): + convert_block = [] + for node in matched: + if node.op == "call_module": + convert_block.append(modules.get(node.target)) + elif node.op == "call_function": + convert_block.append(node.target) + convert_fn(*convert_block) + + for module in self.traced.modules(): + if module_contains_param(module, FakeStructuredSparsity): + raise Exception( # noqa: TRY002 + f"Error: {module} still contains FakeStructuredSparsity parametrizations!" + ) + + self.traced.graph.lint() + self.traced.recompile() + return self.traced # type: ignore[return-value] diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..e8acbc5e458c65a83bd4d05608d46b43cdb94722 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py @@ -0,0 +1,53 @@ +# mypy: allow-untyped-defs +from typing import cast + +import torch + +from .base_structured_sparsifier import BaseStructuredSparsifier, FakeStructuredSparsity + + +class LSTMSaliencyPruner(BaseStructuredSparsifier): + """ + Prune packed LSTM weights based on saliency. + For each layer {k} inside a LSTM, we have two packed weight matrices + - weight_ih_l{k} + - weight_hh_l{k} + + These tensors pack the weights for the 4 linear layers together for efficiency. + + [W_ii | W_if | W_ig | W_io] + + Pruning this tensor directly will lead to weights being misassigned when unpacked. + To ensure that each packed linear layer is pruned the same amount: + 1. We split the packed weight into the 4 constituent linear parts + 2. Update the mask for each individual piece using saliency individually + + This applies to both weight_ih_l{k} and weight_hh_l{k}. + """ + + def update_mask(self, module, tensor_name, **kwargs): + weights = getattr(module, tensor_name) + + for p in getattr(module.parametrizations, tensor_name): + if isinstance(p, FakeStructuredSparsity): + mask = cast(torch.Tensor, p.mask) + + # select weights based on magnitude + if weights.dim() <= 1: + raise Exception( # noqa: TRY002 + "Structured pruning can only be applied to a 2+dim weight tensor!" + ) + # take norm over all but first dim + dims = tuple(range(1, weights.dim())) + saliency = weights.norm(dim=dims, p=1) + + # handle weights in 4 groups + split_size = len(mask) // 4 + masks = torch.split(mask, split_size) + saliencies = torch.split(saliency, split_size) + + for keep_mask, sal in zip(masks, saliencies): + # mask smallest k values to be removed + k = int(len(keep_mask) * kwargs["sparsity_level"]) + prune = sal.topk(k, largest=False, sorted=False).indices + keep_mask.data[prune] = False # modifies underlying p.mask directly diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..64ef6d78c58c7887a2799182fbc904dfcde39b50 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/match_utils.py @@ -0,0 +1,65 @@ +""" +Contains utility functions to check if a pattern is in the graph and return the matching nodes +""" + +from typing import Any, Optional, Union + +import torch +from torch import nn +from torch.ao.quantization.utils import MatchAllNode +from torch.fx import Node +from torch.nn.utils import parametrize + + +def _match( + modules: dict[str, nn.ModuleDict], + node: Node, + current: Union[nn.Module, Any], +) -> bool: + r""" + checks to see if a single node of a pattern matches + """ + if isinstance(current, type) and issubclass(current, MatchAllNode): + return True + if not isinstance(node, Node): + return False + if isinstance(current, type) and issubclass(current, torch.nn.Module): + return ( + node.op == "call_module" + and parametrize.type_before_parametrizations(modules[node.target]) # type: ignore[index] + == current + ) + elif callable(current): + return node.op == "call_function" and node.target is current + elif isinstance(current, str): + return node.target == current + return False + + +def apply_match( + modules: dict[str, nn.ModuleDict], + pattern: Union[tuple[Any], Any], + node: Node, + matched_node_pattern: list[Node], +) -> Optional[list[Node]]: + r""" + This function will return the matched nodes if the pattern matches the node given + If there is no match, it will return None + """ + if isinstance(pattern, tuple): + if len(pattern) == 1: + if _match(modules, node, pattern[0]): + return matched_node_pattern + [node] + + first, *rest = pattern + if _match(modules, node, first): + if rest is None: + return matched_node_pattern + [node] + + for user in node.users: + return apply_match( + modules, tuple(rest), user, matched_node_pattern + [node] + ) + elif _match(modules, node, pattern): + return [node] + return None diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py new file mode 100644 index 0000000000000000000000000000000000000000..58b3f7651caab971ff524c85e00d6448a77a932d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/parametrization.py @@ -0,0 +1,59 @@ +# mypy: allow-untyped-defs +import torch +from torch import nn +from torch.nn.utils.parametrize import is_parametrized + + +def module_contains_param(module, parametrization): + if is_parametrized(module): + # see if any of the module tensors have a parametriztion attached that matches the one passed in + return any( + any(isinstance(param, parametrization) for param in param_list) + for key, param_list in module.parametrizations.items() + ) + return False + + +# Structured Pruning Parameterizations +class FakeStructuredSparsity(nn.Module): + r""" + Parametrization for Structured Pruning. Like FakeSparsity, this should be attached to + the 'weight' or any other parameter that requires a mask. + + Instead of an element-wise bool mask, this parameterization uses a row-wise bool mask. + """ + + def __init__(self, mask): + super().__init__() + self.register_buffer("mask", mask) + + def forward(self, x): + assert isinstance(self.mask, torch.Tensor) + assert self.mask.shape[0] == x.shape[0] + shape = [1] * len(x.shape) + shape[0] = -1 + return self.mask.reshape(shape) * x + + def state_dict(self, *args, **kwargs): + # avoid double saving masks + return {} + + +class BiasHook: + def __init__(self, parametrization, prune_bias): + self.param = parametrization + self.prune_bias = prune_bias + + def __call__(self, module, input, output): + if getattr(module, "_bias", None) is not None: + bias = module._bias.data + if self.prune_bias: + bias[~self.param.mask] = 0 + + # reshape bias to broadcast over output dimensions + idx = [1] * len(output.shape) + idx[1] = -1 + bias = bias.reshape(idx) + + output += bias + return output diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..a1882af4ca11cc156bb8924e791134fda3418be0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/prune_functions.py @@ -0,0 +1,479 @@ +# mypy: allow-untyped-defs +""" +Collection of conversion functions for linear / conv2d structured pruning +Also contains utilities for bias propagation +""" + +from typing import Callable, cast, Optional + +import torch +from torch import nn, Tensor +from torch.nn.utils import parametrize +from torch.nn.utils.parametrize import ParametrizationList + +from .parametrization import BiasHook, FakeStructuredSparsity + + +# BIAS PROPAGATION +def _remove_bias_handles(module: nn.Module) -> None: + if hasattr(module, "_forward_hooks"): + bias_hooks: list[int] = [] + for key, hook in module._forward_hooks.items(): + if isinstance(hook, BiasHook): + bias_hooks.append(key) + + for key in bias_hooks: + del module._forward_hooks[key] + + +def _get_adjusted_next_layer_bias( + next_layer: nn.Module, pruned_biases: Tensor, mask: Tensor +) -> nn.Parameter: + r"""Returns new adjusted bias for the second supported module""" + if parametrize.is_parametrized(next_layer): + # need to access original weight + parametrization_dict = cast(nn.ModuleDict, next_layer.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + next_weight = weight_parameterizations.original + else: + next_weight = cast(Tensor, next_layer.weight) + + scaling_weight = next_weight[:, ~mask] + if isinstance(next_layer, nn.Conv2d): # checking for Conv2d + # Propagating first layer pruned biases and calculating the new second layer bias + # involves more steps since the Conv2d scaling weight has extra dimensions, + # so adding bias involves broadcasting, logically: + # for each channel k in range(oC): + # scaled_biases = sum(first_bias[pruned_idx] @ next_weight[k, pruned_idx, :, :].T) + # new_next_bias[k] = old_next_bias[k] + scaled_biases + scaling_product = torch.matmul( + pruned_biases.reshape(1, -1), torch.transpose(scaling_weight, 1, 2) + ) + sum_range = list(range(len(scaling_product.shape)))[ + 1: + ] # all but the first dimension + scaled_biases = torch.sum(scaling_product, sum_range) + elif isinstance(next_layer, nn.Linear): # Linear + scaled_biases = torch.matmul( + pruned_biases, torch.transpose(scaling_weight, 0, 1) + ) # recall b2_new = b1 @ w2.T + b2 + else: + raise NotImplementedError(f"Type {type(next_layer)} not supported yet.") + + if ( + parametrize.is_parametrized(next_layer) + and getattr(next_layer, "_bias", None) is not None + ): # next_layer is parametrized & has original bias ._bias + adjusted_bias = nn.Parameter(scaled_biases + next_layer._bias) # type: ignore[operator] + elif ( + not parametrize.is_parametrized(next_layer) and next_layer.bias is not None + ): # next_layer not parametrized & has .bias + adjusted_bias = nn.Parameter(scaled_biases + next_layer.bias) # type: ignore[operator] + else: # next_layer has no bias + adjusted_bias = nn.Parameter(scaled_biases) + return adjusted_bias + + +def _prune_module_bias(module: nn.Module, mask: Tensor) -> None: + r"""Applies mask to given modules bias""" + # prune bias along with weights, discard pruned indices of bias + original_bias = cast(Tensor, getattr(module, "_bias", module.bias)) + if original_bias is not None: + module.bias = nn.Parameter(original_bias[mask]) + + # remove _bias parameter + if hasattr(module, "_bias"): + delattr(module, "_bias") + + +def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]: + r""" + In the case that we need to propagate biases, this function will return the biases we need + """ + # set current module bias + if module.bias is not None: + module.bias = nn.Parameter(cast(Tensor, module.bias)[mask]) + elif getattr(module, "_bias", None) is not None: + module.bias = nn.Parameter(cast(Tensor, module._bias)[mask]) + + # get pruned biases to propagate to subsequent layer + if getattr(module, "_bias", None) is not None: + pruned_biases = cast(Tensor, module._bias)[~mask] + else: + pruned_biases = None + + if hasattr(module, "_bias"): + delattr(module, "_bias") + + return pruned_biases + + +# LINEAR +def _prune_linear_helper(linear: nn.Linear) -> Tensor: + # expects linear to be a parameterized linear module + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True) + linear.weight = nn.Parameter(linear.weight[mask]) # type: ignore[possibly-undefined] + linear.out_features = linear.weight.shape[0] + _remove_bias_handles(linear) + + return mask + + +def prune_linear(linear: nn.Linear) -> None: + mask = _prune_linear_helper(linear) + if getattr(linear, "prune_bias", False): + _prune_module_bias(linear, mask) + + +def prune_linear_linear(linear1: nn.Linear, linear2: nn.Linear) -> None: + prune_linear_activation_linear(linear1, None, linear2) + + +def prune_linear_activation_linear( + linear1: nn.Linear, + activation: Optional[Callable[[Tensor], Tensor]], + linear2: nn.Linear, +): + mask = _prune_linear_helper(linear1) + if getattr(linear1, "prune_bias", False): + _prune_module_bias(linear1, mask) + else: + pruned_biases = _propagate_module_bias(linear1, mask) + if pruned_biases is not None: + if activation: + pruned_biases = activation(pruned_biases) + linear2.bias = _get_adjusted_next_layer_bias(linear2, pruned_biases, mask) + + with torch.no_grad(): + if parametrize.is_parametrized(linear2): + parametrization_dict = cast(nn.ModuleDict, linear2.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, mask] + ) + linear2.in_features = weight_parameterizations.original.shape[1] + else: + linear2.weight = nn.Parameter(linear2.weight[:, mask]) + linear2.in_features = linear2.weight.shape[1] + + +# CONV2D +def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor: + parametrization_dict = cast(nn.ModuleDict, conv2d.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(conv2d, "weight", leave_parametrized=True) + conv2d.weight = nn.Parameter(conv2d.weight[mask]) # type: ignore[possibly-undefined] + conv2d.out_channels = conv2d.weight.shape[0] + + _remove_bias_handles(conv2d) + return mask + + +def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None: + parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(conv2d_1, "weight", leave_parametrized=True) + + if getattr(conv2d_1, "_bias", None) is not None: + if ( + conv2d_1.bias is not None + ): # conv2d_1 has original bias and bias propagated from previous layer + new_bias = torch.zeros(conv2d_1.bias.shape) + new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined] + # adjusted bias that to keep in conv2d_1 + new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask] + # pruned biases that are kept instead of propagated + conv2d_1.bias = nn.Parameter(new_bias) + else: # conv2d_1 has only original bias + conv2d_1.bias = nn.Parameter(cast(Tensor, conv2d_1._bias)) + else: + # no original bias, only propagated bias + if ( + conv2d_1.bias is not None + ): # conv2d_1 has bias propagated from previous layer + conv2d_1.bias.data[~mask] = 0 # type: ignore[possibly-undefined] + + if hasattr(conv2d_1, "_bias"): + delattr(conv2d_1, "_bias") + + +def prune_conv2d(conv2d: nn.Conv2d) -> None: + mask = _prune_conv2d_helper(conv2d) + if getattr(conv2d, "prune_bias", False): + _prune_module_bias(conv2d, mask) + + +def prune_conv2d_conv2d(conv2d_1: nn.Conv2d, conv2d_2: nn.Conv2d) -> None: + prune_conv2d_activation_conv2d(conv2d_1, None, conv2d_2) + + +def prune_conv2d_activation_conv2d( + conv2d_1: nn.Conv2d, + activation: Optional[Callable[[Tensor], Tensor]], + conv2d_2: nn.Conv2d, +): + r""" + Fusion Pattern for conv2d -> some activation module / function -> conv2d layers + """ + parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + prune_bias = getattr(conv2d_1, "prune_bias", False) + if ( + hasattr(conv2d_2, "padding") + and cast(tuple[int], conv2d_2.padding) > (0, 0) + and (conv2d_1.bias is not None or getattr(conv2d_1, "_bias", None) is not None) + ): + prune_conv2d_padded(conv2d_1) + else: + mask = _prune_conv2d_helper(conv2d_1) + if prune_bias: + _prune_module_bias(conv2d_1, mask) + else: + pruned_biases = _propagate_module_bias(conv2d_1, mask) + if pruned_biases is not None: + if activation: + pruned_biases = activation(pruned_biases) + conv2d_2.bias = _get_adjusted_next_layer_bias( + conv2d_2, pruned_biases, mask + ) + + if ( + not ( + hasattr(conv2d_2, "padding") + and cast(tuple[int], conv2d_2.padding) > (0, 0) + ) + or conv2d_1.bias is None + ): + with torch.no_grad(): + if parametrize.is_parametrized(conv2d_2): + parametrization_dict = cast( + nn.ModuleDict, conv2d_2.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, mask] + ) + conv2d_2.in_channels = weight_parameterizations.original.shape[1] + else: + conv2d_2.weight = nn.Parameter(conv2d_2.weight[:, mask]) + conv2d_2.in_channels = conv2d_2.weight.shape[1] + + +def prune_conv2d_pool_activation_conv2d( + c1: nn.Conv2d, + pool: nn.Module, + activation: Optional[Callable[[Tensor], Tensor]], + c2: nn.Conv2d, +) -> None: + prune_conv2d_activation_conv2d(c1, activation, c2) + + +def prune_conv2d_activation_pool_conv2d( + c1: nn.Conv2d, + activation: Optional[Callable[[Tensor], Tensor]], + pool: nn.Module, + c2: nn.Conv2d, +) -> None: + prune_conv2d_activation_conv2d(c1, activation, c2) + + +def prune_conv2d_pool_flatten_linear( + conv2d: nn.Conv2d, + pool: nn.Module, + flatten: Optional[Callable[[Tensor], Tensor]], + linear: nn.Linear, +) -> None: + mask = _prune_conv2d_helper(conv2d) + + # We map the pruned indices of the Conv2d output to the flattened indices of the Linear following the Flatten layer. + # we determine the flattening scale (h * w), and readjust `first_pruned_indices` + # (each idx maps to range idx * h * w to (idx+1) * h * w), `first_valid_indices`, + # and `pruned_biases` (repeat each bias by h * w). + if parametrize.is_parametrized(linear): + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + linear_ic = weight_parameterizations.original.shape[1] + else: + linear_ic = linear.weight.shape[1] + + conv2d_oc = len(mask) + assert linear_ic % conv2d_oc == 0, ( + f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported" + ) + + flatten_scale = linear_ic // conv2d_oc + flattened_mask = torch.tensor( + [[val] * flatten_scale for val in mask], dtype=torch.bool, device=mask.device + ).flatten() + + if getattr(conv2d, "prune_bias", False): + _prune_module_bias(conv2d, mask) + else: + pruned_biases = cast(Tensor, _propagate_module_bias(conv2d, mask)) + flattened_pruned_biases = torch.tensor( + [[bias] * flatten_scale for bias in pruned_biases], device=mask.device + ).flatten() + linear.bias = _get_adjusted_next_layer_bias( + linear, flattened_pruned_biases, flattened_mask + ) + + with torch.no_grad(): + if parametrize.is_parametrized(linear): + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, flattened_mask] + ) + linear.in_features = weight_parameterizations.original.shape[1] + else: + linear.weight = nn.Parameter(linear.weight[:, flattened_mask]) + linear.in_features = linear.weight.shape[1] + + +def prune_lstm_output_linear( + lstm: nn.LSTM, getitem: Callable, linear: nn.Linear +) -> None: + prune_lstm_output_layernorm_linear(lstm, getitem, None, linear) + + +def prune_lstm_output_layernorm_linear( + lstm: nn.LSTM, + getitem: Callable, + layernorm: Optional[nn.LayerNorm], + linear: nn.Linear, +) -> None: + for i in range(lstm.num_layers): + if parametrize.is_parametrized(lstm, f"weight_ih_l{i}"): + parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict[f"weight_ih_l{i}"] + ) + mask = weight_parameterizations[0].mask + + with torch.no_grad(): + parametrize.remove_parametrizations( + lstm, f"weight_ih_l{i}", leave_parametrized=True + ) + setattr( + lstm, + f"weight_ih_l{i}", + nn.Parameter(getattr(lstm, f"weight_ih_l{i}")[mask]), + ) + setattr( + lstm, + f"bias_ih_l{i}", + nn.Parameter(getattr(lstm, f"bias_ih_l{i}")[mask]), + ) + + if parametrize.is_parametrized(lstm, f"weight_hh_l{i}"): + parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict[f"weight_hh_l{i}"] + ) + mask = weight_parameterizations[0].mask + + with torch.no_grad(): + parametrize.remove_parametrizations( + lstm, f"weight_hh_l{i}", leave_parametrized=True + ) + # splitting out hidden-hidden masks + W_hi, W_hf, W_hg, W_ho = torch.split( + getattr(lstm, f"weight_hh_l{i}"), lstm.hidden_size + ) + M_hi, M_hf, M_hg, M_ho = torch.split(mask, lstm.hidden_size) # type: ignore[arg-type] + + # resize each individual weight separately + W_hi = W_hi[M_hi][:, M_hi] + W_hf = W_hf[M_hf][:, M_hf] + W_hg = W_hg[M_hg][:, M_hg] + W_ho = W_ho[M_ho][:, M_ho] + + # concat, use this as new weight + new_weight = torch.cat((W_hi, W_hf, W_hg, W_ho)) + setattr(lstm, f"weight_hh_l{i}", nn.Parameter(new_weight)) + setattr( + lstm, + f"bias_hh_l{i}", + nn.Parameter(getattr(lstm, f"bias_hh_l{i}")[mask]), + ) + + # If this is the final layer, then we need to prune linear layer columns + if i + 1 == lstm.num_layers: + lstm.hidden_size = int(M_hi.sum()) + with torch.no_grad(): + if parametrize.is_parametrized(linear): + parametrization_dict = cast( + nn.ModuleDict, linear.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, M_ho] + ) + linear.in_features = weight_parameterizations.original.shape[1] + else: + linear.weight = nn.Parameter(linear.weight[:, M_ho]) + linear.in_features = linear.weight.shape[1] + + # if layernorm module, prune weight and bias + if layernorm is not None: + layernorm.normalized_shape = (linear.in_features,) + layernorm.weight = nn.Parameter(layernorm.weight[M_ho]) + layernorm.bias = nn.Parameter(layernorm.bias[M_ho]) + + # otherwise need to prune the columns of the input of the next LSTM layer + else: + with torch.no_grad(): + if parametrize.is_parametrized(lstm, f"weight_ih_l{i + 1}"): + parametrization_dict = cast( + nn.ModuleDict, lstm.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, + getattr(parametrization_dict, f"weight_ih_l{i + 1}"), + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, M_ho] + ) + else: + next_layer_weight = getattr(lstm, f"weight_ih_l{i + 1}") + setattr( + lstm, + f"weight_ih_l{i + 1}", + nn.Parameter(next_layer_weight[:, M_ho]), + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..a295b4622cc2d64714d4dab969a8923a6014a55d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py @@ -0,0 +1,32 @@ +# mypy: allow-untyped-defs +from .base_structured_sparsifier import BaseStructuredSparsifier + + +class SaliencyPruner(BaseStructuredSparsifier): + """ + Prune rows based on the saliency (L1 norm) of each row. + + This pruner works on N-Dimensional weight tensors. + For each row, we will calculate the saliency, whic is the sum the L1 norm of all weights in that row. + We expect that the resulting saliency vector has the same shape as our mask. + We then pick elements to remove until we reach the target sparsity_level. + """ + + def update_mask(self, module, tensor_name, **kwargs): + # tensor_name will give you the FQN, all other entries in sparse config is present in kwargs + weights = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + # use negative weights so we can use topk (we prune out the smallest) + if weights.dim() <= 1: + raise Exception( # noqa: TRY002 + "Structured pruning can only be applied to a 2+dim weight tensor!" + ) + saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1) + assert saliency.shape == mask.shape + + num_to_pick = int(len(mask) * kwargs["sparsity_level"]) + prune = saliency.topk(num_to_pick).indices + + # Set the mask to be false for the rows we want to prune + mask.data[prune] = False diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/_mappings.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc2c4f10aef5585072f36116282a2048965197a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/_mappings.py @@ -0,0 +1,23 @@ +# mypy: allow-untyped-defs +__all__ = [ + "get_static_sparse_quantized_mapping", + "get_dynamic_sparse_quantized_mapping", +] + + +def get_static_sparse_quantized_mapping(): + import torch.ao.nn.sparse + + _static_sparse_quantized_mapping = { + torch.nn.Linear: torch.ao.nn.sparse.quantized.Linear, + } + return _static_sparse_quantized_mapping + + +def get_dynamic_sparse_quantized_mapping(): + import torch.ao.nn.sparse + + _dynamic_sparse_quantized_mapping = { + torch.nn.Linear: torch.ao.nn.sparse.quantized.dynamic.Linear, + } + return _dynamic_sparse_quantized_mapping diff --git a/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py b/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d42ea803289c5864c0c669e6b3e8fef062246a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py @@ -0,0 +1,60 @@ +# mypy: allow-untyped-defs +import torch + +from . import base_sparsifier + + +class NearlyDiagonalSparsifier(base_sparsifier.BaseSparsifier): + r"""Nearly Diagonal Sparsifier + + This sparsifier creates a nearly diagonal mask to be applied to the weight matrix. + Nearly Diagonal Matrix is a matrix that contains non-zero elements near the diagonal and the rest are zero. + An example of a nearly diagonal matrix with degree (or nearliness) 3 and 5 are follows respectively. + 1 1 0 0 1 1 1 0 + 1 1 1 0 1 1 1 1 + 0 1 1 1 1 1 1 1 + 0 0 1 1 0 1 1 1 + Note that a nearly diagonal matrix with degree 1 is just a matrix with main diagonal populated + + This sparsifier is controlled by one variable: + 1. `nearliness` defines the number of non-zero diagonal lines that are closest to the main diagonal. + Currently - supports only odd number + + Note: + This can be accelerated (vectorized) once the Spdiagonal feature (PR: #78439) is landed or the banded matrix + feature is landed: https://stackoverflow.com/questions/52463972/generating-banded-matrices-using-numpy + + Args: + nearliness: The degree of nearliness (default = 1) + + """ + + def __init__(self, nearliness: int = 1): + defaults = {"nearliness": nearliness} + super().__init__(defaults=defaults) + + def update_mask( # type:ignore[override] + self, module, tensor_name, nearliness, **kwargs + ): + mask = getattr(module.parametrizations, tensor_name)[0].mask + mask.data = torch.zeros_like(mask) + if nearliness <= 0: + return + + tensor = getattr(module, tensor_name) + height, width = tensor.shape + + if nearliness % 2 == 0: + raise ValueError("nearliness can only be an odd number") + dist_to_diagonal = nearliness // 2 + # check + if dist_to_diagonal >= min(height, width): + raise ValueError( + "nearliness cannot be larger than the dimensions of tensor." + ) + + for row in range(0, height): + # Bounds of entries that needs to be set to 1 + low = max(0, row - dist_to_diagonal) + high = min(width, row + dist_to_diagonal + 1) + mask[row, low:high].fill_(1) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/__init__.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc1792fd23faf3e914f195ed175619f43987ff2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/__init__.py @@ -0,0 +1,234 @@ +# mypy: allow-untyped-defs + +from typing import Callable, Optional, Union + +import torch +from torch import Tensor + +from .fake_quantize import * # noqa: F403 +from .fuse_modules import fuse_modules, fuse_modules_qat # noqa: F403 +from .fuser_method_mappings import * # noqa: F403 +from .observer import * # noqa: F403 +from .pt2e._numeric_debugger import ( # noqa: F401 + compare_results, + CUSTOM_KEY, + extract_results_from_loggers, + generate_numeric_debug_handle, + NUMERIC_DEBUG_HANDLE_KEY, + prepare_for_propagation_comparison, +) +from .pt2e.export_utils import ( + _allow_exported_model_train_eval as allow_exported_model_train_eval, + _move_exported_model_to_eval as move_exported_model_to_eval, + _move_exported_model_to_train as move_exported_model_to_train, +) +from .qconfig import * # noqa: F403 +from .qconfig_mapping import * # noqa: F403 +from .quant_type import * # noqa: F403 +from .quantization_mappings import * # noqa: F403 # type: ignore[no-redef] +from .quantize import * # noqa: F403 +from .quantize_jit import * # noqa: F403 +from .stubs import * # noqa: F403 + + +# ensure __module__ is set correctly for public APIs +ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase] +ObserverOrFakeQuantize.__module__ = "torch.ao.quantization" +for _f in [ + compare_results, + extract_results_from_loggers, + generate_numeric_debug_handle, + prepare_for_propagation_comparison, +]: + _f.__module__ = "torch.ao.quantization" + +__all__ = [ + "DeQuantStub", + "FakeQuantize", + "FakeQuantizeBase", + "FixedQParamsFakeQuantize", + "FixedQParamsObserver", + "FusedMovingAvgObsFakeQuantize", + "HistogramObserver", + "MatchAllNode", + "MinMaxObserver", + "MovingAverageMinMaxObserver", + "MovingAveragePerChannelMinMaxObserver", + "NoopObserver", + "ObserverBase", + "ObserverOrFakeQuantize", + "Pattern", + "PerChannelMinMaxObserver", + "PlaceholderObserver", + "QConfig", + "QConfigAny", + "QConfigDynamic", + "QConfigMapping", + "QuantStub", + "QuantType", + "QuantWrapper", + "RecordingObserver", + "ReuseInputObserver", + "UniformQuantizationObserverBase", + "add_quant_dequant", + "convert", + "convert_dynamic_jit", + "convert_jit", + "default_affine_fixed_qparams_fake_quant", + "default_affine_fixed_qparams_observer", + "default_debug_observer", + "default_dynamic_fake_quant", + "default_dynamic_quant_observer", + "default_embedding_fake_quant", + "default_embedding_fake_quant_4bit", + "default_eval_fn", + "default_fake_quant", + "default_fixed_qparams_range_0to1_fake_quant", + "default_fixed_qparams_range_0to1_observer", + "default_fixed_qparams_range_neg1to1_fake_quant", + "default_fixed_qparams_range_neg1to1_observer", + "default_float_qparams_observer", + "default_float_qparams_observer_4bit", + "default_fused_act_fake_quant", + "default_fused_per_channel_wt_fake_quant", + "default_fused_wt_fake_quant", + "default_histogram_fake_quant", + "default_histogram_observer", + "default_observer", + "default_per_channel_weight_fake_quant", + "default_per_channel_weight_observer", + "default_placeholder_observer", + "default_reuse_input_observer", + "default_symmetric_fixed_qparams_fake_quant", + "default_symmetric_fixed_qparams_observer", + "default_weight_fake_quant", + "default_weight_observer", + "disable_fake_quant", + "disable_observer", + "enable_fake_quant", + "enable_observer", + "fuse_conv_bn", + "fuse_conv_bn_jit", + "fuse_conv_bn_relu", + "fuse_convtranspose_bn", + "fuse_linear_bn", + "fuse_modules", + "fuse_modules_qat", + "fused_per_channel_wt_fake_quant_range_neg_127_to_127", + "fused_wt_fake_quant_range_neg_127_to_127", + "get_combined_dict", + "get_default_compare_output_module_list", + "get_default_custom_config_dict", + "get_default_dynamic_quant_module_mappings", + "get_default_dynamic_sparse_quant_module_mappings", + "get_default_float_to_quantized_operator_mappings", + "get_default_qat_module_mappings", + "get_default_qat_qconfig", + "get_default_qat_qconfig_dict", + "get_default_qat_qconfig_mapping", + "get_default_qconfig", + "get_default_qconfig_dict", + "get_default_qconfig_mapping", + "get_default_qconfig_propagation_list", + "get_default_static_quant_module_mappings", + "get_default_static_quant_reference_module_mappings", + "get_default_static_sparse_quant_module_mappings", + "get_dynamic_quant_module_class", + "get_embedding_qat_module_mappings", + "get_embedding_static_quant_module_mappings", + "get_fuser_method", + "get_fuser_method_new", + "get_observer_state_dict", + "get_quantized_operator", + "get_static_quant_module_class", + "load_observer_state_dict", + "move_exported_model_to_eval", + "move_exported_model_to_train", + "allow_exported_model_train_eval", + "no_observer_set", + "per_channel_weight_observer_range_neg_127_to_127", + "prepare", + "prepare_dynamic_jit", + "prepare_jit", + "prepare_qat", + "propagate_qconfig_", + "qconfig_equals", + "quantize", + "quantize_dynamic", + "quantize_dynamic_jit", + "quantize_jit", + "quantize_qat", + "script_qconfig", + "script_qconfig_dict", + "swap_module", + "weight_observer_range_neg_127_to_127", + "generate_numeric_debug_handle", + "CUSTOM_KEY", + "NUMERIC_DEBUG_HANDLE_KEY", + "prepare_for_propagation_comparison", + "extract_results_from_loggers", + "compare_results", + # from torchao, should be merged with torchao + # in the future + "AffineQuantizedObserverBase", + "Granularity", + "MappingType", + "PerAxis", + "PerBlock", + "PerGroup", + "PerRow", + "PerTensor", + "PerToken", + "TorchAODType", + "ZeroPointDomain", + "get_block_size", +] + + +def default_eval_fn(model, calib_data): + r"""Define the default evaluation function. + + Default evaluation function takes a torch.utils.data.Dataset or a list of + input Tensors and run the model on the dataset + """ + for data, _target in calib_data: + model(data) + + +class _DerivedObserverOrFakeQuantize(ObserverBase): + r"""This observer is used to describe an observer whose quantization parameters + are derived from other observers + """ + + def __init__( + self, + dtype: torch.dtype, + obs_or_fqs: list[ObserverOrFakeQuantize], + derive_qparams_fn: Callable[ + [list[ObserverOrFakeQuantize]], tuple[Tensor, Tensor] + ], + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + qscheme: Optional[torch.qscheme] = None, + ch_axis: Optional[int] = None, + ): + super().__init__(dtype) + self.obs_or_fqs = obs_or_fqs + self.derive_qparams_fn = derive_qparams_fn + self.quant_min = quant_min + self.quant_max = quant_max + self.qscheme = qscheme + self.ch_axis = ch_axis + + from .utils import is_per_channel + + if is_per_channel(self.qscheme): + assert self.ch_axis is not None, ( + "Must provide a valid ch_axis if qscheme is per channel" + ) + + def forward(self, x: Tensor) -> Tensor: + return x + + def calculate_qparams(self): # type:ignore[override] + return self.derive_qparams_fn(self.obs_or_fqs) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/_correct_bias.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/_correct_bias.py new file mode 100644 index 0000000000000000000000000000000000000000..3f480486893d4ed6fe7e67bc36123ececd78a7dc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/_correct_bias.py @@ -0,0 +1,156 @@ +# mypy: allow-untyped-defs +import torch +import torch.ao.nn.quantized as nnq +import torch.ao.ns._numeric_suite as ns +import torch.ao.quantization +import torch.nn as nn + + +__all__ = [ + "get_module", + "parent_child_names", + "get_param", + "MeanShadowLogger", + "bias_correction", +] + +_supported_modules = {nn.Linear, nn.Conv2d} +_supported_modules_quantized = {nnq.Linear, nnq.Conv2d} + + +def get_module(model, name): + """Given name of submodule, this function grabs the submodule from given model.""" + return dict(model.named_modules())[name] + + +def parent_child_names(name): + """Split full name of submodule into parent submodule's full name and submodule's name.""" + split_name = name.rsplit(".", 1) + if len(split_name) == 1: + return "", split_name[0] + else: + return split_name[0], split_name[1] + + +def get_param(module, attr): + """Get the parameter given a module and attribute. + + Sometimes the weights/bias attribute gives you the raw tensor, but sometimes + gives a function that will give you the raw tensor, this function takes care of that logic + """ + param = getattr(module, attr, None) + if callable(param): + return param() + else: + return param + + +class MeanShadowLogger(ns.Logger): + """Mean Logger for a Shadow module. + + A logger for a Shadow module whose purpose is to record the rolling mean + of the data passed to the floating point and quantized models + """ + + def __init__(self): + """Set up initial values for float and quantized stats, count, float sum, and quant sum.""" + super().__init__() + self.stats["float"] = None + self.stats["quantized"] = None + self.count = 0 + self.float_sum = None + self.quant_sum = None + + def forward(self, x, y): # type: ignore[override] + """Compute the average of quantized and floating-point data from modules. + + The inputs x,y are output data from the quantized and floating-point modules. + x is for the quantized module, y is for the floating point module + """ + if x.is_quantized: + x = x.dequantize() + + self.count += 1 + if self.stats["quantized"] is None: + self.stats["quantized"] = x + self.quant_sum = x + else: + self.quant_sum += x + self.stats["quantized"] = self.quant_sum / self.count + + if self.stats["float"] is None: + self.stats["float"] = y + self.float_sum = y + else: + self.float_sum += y + self.stats["float"] = self.float_sum / self.count + + def clear(self): + self.stats["float"] = None + self.stats["quantized"] = None + self.count = 0 + self.float_sum = None + self.quant_sum = None + + +def bias_correction( + float_model, + quantized_model, + img_data, + target_modules=_supported_modules_quantized, + neval_batches=None, +): + """Perform bias correction on a module. + + Using numeric suite shadow module, the expected output of the floating point and quantized modules + is recorded. Using that data the bias of supported modules is shifted to compensate for the drift caused + by quantization + Paper reference: https://arxiv.org/pdf/1906.04721.pdf (Section 4.2) + + Args: + float_model: a trained model that serves as a reference to what bias correction should aim for + quantized_model: quantized form of float_model that bias correction is to applied to + img_data: calibration data to estimate the expected output (used to find quantization error) + target_modules: specifies what submodules in quantized_model need bias correction (can be extended to + unquantized submodules) + neval_batches: a cap to the number of batches you want to be used for estimating the expected output + """ + ns.prepare_model_with_stubs( + float_model, quantized_model, _supported_modules, MeanShadowLogger + ) + + uncorrected_modules = { + name: submodule + for name, submodule in quantized_model.named_modules() + if type(submodule) in target_modules + } + + for uncorrected_module in uncorrected_modules: + quantized_submodule = get_module(quantized_model, uncorrected_module) + bias = get_param(quantized_submodule, "bias") + if bias is not None: + for count, data in enumerate(img_data, start=1): + quantized_model(data[0]) + if count == neval_batches: + break + ob_dict = ns.get_logger_dict(quantized_model) + parent_name, _ = parent_child_names(uncorrected_module) + + float_data = ob_dict[parent_name + ".stats"]["float"] + quant_data = ob_dict[parent_name + ".stats"]["quantized"] + + # math for expected_error + quantization_error = quant_data - float_data + dims = list(range(quantization_error.dim())) + # Note: we don't want to take the mean over the output channel dimension + dims.remove(1) + expected_error = torch.mean(quantization_error, dims) + + updated_bias = bias.data - expected_error + + bias.data = updated_bias + + # Resets the data contained in the loggers + for name, submodule in quantized_model.named_modules(): + if isinstance(submodule, MeanShadowLogger): + submodule.clear() diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/_equalize.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/_equalize.py new file mode 100644 index 0000000000000000000000000000000000000000..5d79f7f71b4f2e39ba62ffac449c6be31b40d4a5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/_equalize.py @@ -0,0 +1,278 @@ +# mypy: allow-untyped-defs +import copy +from itertools import chain +from typing import Any + +import torch + + +__all__ = [ + "set_module_weight", + "set_module_bias", + "has_bias", + "get_module_weight", + "get_module_bias", + "max_over_ndim", + "min_over_ndim", + "channel_range", + "get_name_by_module", + "cross_layer_equalization", + "process_paired_modules_list_to_name", + "expand_groups_in_paired_modules_list", + "equalize", + "converged", +] + +_supported_types = {torch.nn.Conv2d, torch.nn.Linear, torch.nn.Conv1d} +_supported_intrinsic_types = { + torch.ao.nn.intrinsic.ConvReLU2d, + torch.ao.nn.intrinsic.LinearReLU, + torch.ao.nn.intrinsic.ConvReLU1d, +} +_all_supported_types = _supported_types.union(_supported_intrinsic_types) + + +def set_module_weight(module, weight) -> None: + if type(module) in _supported_types: + module.weight = torch.nn.Parameter(weight) + else: + module[0].weight = torch.nn.Parameter(weight) + + +def set_module_bias(module, bias) -> None: + if type(module) in _supported_types: + module.bias = torch.nn.Parameter(bias) + else: + module[0].bias = torch.nn.Parameter(bias) + + +def has_bias(module) -> bool: + if type(module) in _supported_types: + return module.bias is not None + else: + return module[0].bias is not None + + +def get_module_weight(module): + if type(module) in _supported_types: + return module.weight + else: + return module[0].weight + + +def get_module_bias(module): + if type(module) in _supported_types: + return module.bias + else: + return module[0].bias + + +def max_over_ndim(input, axis_list, keepdim=False): + """Apply 'torch.max' over the given axes.""" + axis_list.sort(reverse=True) + for axis in axis_list: + input, _ = input.max(axis, keepdim) + return input + + +def min_over_ndim(input, axis_list, keepdim=False): + """Apply 'torch.min' over the given axes.""" + axis_list.sort(reverse=True) + for axis in axis_list: + input, _ = input.min(axis, keepdim) + return input + + +def channel_range(input, axis=0): + """Find the range of weights associated with a specific channel.""" + size_of_tensor_dim = input.ndim + axis_list = list(range(size_of_tensor_dim)) + axis_list.remove(axis) + + mins = min_over_ndim(input, axis_list) + maxs = max_over_ndim(input, axis_list) + + assert mins.size(0) == input.size(axis), ( + "Dimensions of resultant channel range does not match size of requested axis" + ) + return maxs - mins + + +def get_name_by_module(model, module): + """Get the name of a module within a model. + + Args: + model: a model (nn.module) that equalization is to be applied on + module: a module within the model + + Returns: + name: the name of the module within the model + """ + for name, m in model.named_modules(): + if m is module: + return name + raise ValueError("module is not in the model") + + +def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1): + """Scale the range of Tensor1.output to equal Tensor2.input. + + Given two adjacent tensors', the weights are scaled such that + the ranges of the first tensors' output channel are equal to the + ranges of the second tensors' input channel + """ + if ( + type(module1) not in _all_supported_types + or type(module2) not in _all_supported_types + ): + raise ValueError( + "module type not supported:", type(module1), " ", type(module2) + ) + + bias = get_module_bias(module1) if has_bias(module1) else None + + weight1 = get_module_weight(module1) + weight2 = get_module_weight(module2) + + if weight1.size(output_axis) != weight2.size(input_axis): + raise TypeError( + "Number of output channels of first arg do not match \ + number input channels of second arg" + ) + + weight1_range = channel_range(weight1, output_axis) + weight2_range = channel_range(weight2, input_axis) + + # producing scaling factors to applied + weight2_range += 1e-9 + scaling_factors = torch.sqrt(weight1_range / weight2_range) + inverse_scaling_factors = torch.reciprocal(scaling_factors) + + if bias is not None: + bias = bias * inverse_scaling_factors + + # formatting the scaling (1D) tensors to be applied on the given argument tensors + # pads axis to (1D) tensors to then be broadcasted + size1 = [1] * weight1.ndim + size1[output_axis] = weight1.size(output_axis) + size2 = [1] * weight2.ndim + size2[input_axis] = weight2.size(input_axis) + + scaling_factors = torch.reshape(scaling_factors, size2) + inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1) + + weight1 = weight1 * inverse_scaling_factors + weight2 = weight2 * scaling_factors + + set_module_weight(module1, weight1) + if bias is not None: + set_module_bias(module1, bias) + set_module_weight(module2, weight2) + + +def process_paired_modules_list_to_name(model, paired_modules_list): + """Processes a list of paired modules to a list of names of paired modules.""" + + for group in paired_modules_list: + for i, item in enumerate(group): + if isinstance(item, torch.nn.Module): + group[i] = get_name_by_module(model, item) + elif not isinstance(item, str): + raise TypeError("item must be a nn.Module or a string") + return paired_modules_list + + +def expand_groups_in_paired_modules_list(paired_modules_list): + """Expands module pair groups larger than two into groups of two modules.""" + new_list = [] + + for group in paired_modules_list: + if len(group) == 1: + raise ValueError("Group must have at least two modules") + elif len(group) == 2: + new_list.append(group) + elif len(group) > 2: + new_list.extend([group[i], group[i + 1]] for i in range(len(group) - 1)) + + return new_list + + +def equalize(model, paired_modules_list, threshold=1e-4, inplace=True): + """Equalize modules until convergence is achieved. + + Given a list of adjacent modules within a model, equalization will + be applied between each pair, this will repeated until convergence is achieved + + Keeps a copy of the changing modules from the previous iteration, if the copies + are not that different than the current modules (determined by converged_test), + then the modules have converged enough that further equalizing is not necessary + + Reference is section 4.1 of this paper https://arxiv.org/pdf/1906.04721.pdf + + Args: + model: a model (nn.Module) that equalization is to be applied on + paired_modules_list (List(List[nn.module || str])): a list of lists + where each sublist is a pair of two submodules found in the model, + for each pair the two modules have to be adjacent in the model, + with only piece-wise-linear functions like a (P)ReLU or LeakyReLU in between + to get expected results. + The list can contain either modules, or names of modules in the model. + If you pass multiple modules in the same list, they will all be equalized together. + threshold (float): a number used by the converged function to determine what degree + of similarity between models is necessary for them to be called equivalent + inplace (bool): determines if function is inplace or not + """ + + paired_modules_list = process_paired_modules_list_to_name( + model, paired_modules_list + ) + + if not inplace: + model = copy.deepcopy(model) + + paired_modules_list = expand_groups_in_paired_modules_list(paired_modules_list) + + name_to_module: dict[str, torch.nn.Module] = {} + previous_name_to_module: dict[str, Any] = {} + name_set = set(chain.from_iterable(paired_modules_list)) + + for name, module in model.named_modules(): + if name in name_set: + name_to_module[name] = module + previous_name_to_module[name] = None + while not converged(name_to_module, previous_name_to_module, threshold): + for pair in paired_modules_list: + previous_name_to_module[pair[0]] = copy.deepcopy(name_to_module[pair[0]]) + previous_name_to_module[pair[1]] = copy.deepcopy(name_to_module[pair[1]]) + + cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]]) + + return model + + +def converged(curr_modules, prev_modules, threshold=1e-4): + """Test whether modules are converged to a specified threshold. + + Tests for the summed norm of the differences between each set of modules + being less than the given threshold + + Takes two dictionaries mapping names to modules, the set of names for each dictionary + should be the same, looping over the set of names, for each name take the difference + between the associated modules in each dictionary + + """ + if curr_modules.keys() != prev_modules.keys(): + raise ValueError( + "The keys to the given mappings must have the same set of names of modules" + ) + + summed_norms = torch.tensor(0.0) + if None in prev_modules.values(): + return False + for name in curr_modules.keys(): + curr_weight = get_module_weight(curr_modules[name]) + prev_weight = get_module_weight(prev_modules[name]) + + difference = curr_weight.sub(prev_weight) + summed_norms += torch.norm(difference) + return bool(summed_norms < threshold) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/_learnable_fake_quantize.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/_learnable_fake_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..d12c96f66c0092a3a39b9a6411e24c16a3b0372d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/_learnable_fake_quantize.py @@ -0,0 +1,201 @@ +# mypy: allow-untyped-defs + +import torch +from torch.nn.parameter import Parameter + + +__all__: list[str] = [] + + +class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase): + r"""Generalized extension of the FakeQuantize module in fake_quantize.py. + + This is an extension of the FakeQuantize module in fake_quantize.py, which + supports more generalized lower-bit quantization and supports learning of the scale + and zero point parameters through backpropagation. + + In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize + module also includes the following attributes to support quantization parameter learning. + + * :attr:`channel_len` defines the length of the channel when initializing scale and zero point + for the per channel case. + + * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are + normalized by the constant, which is proportional to the square root of the number of + elements in the tensor. The related literature justifying the use of this particular constant + can be found here: https://openreview.net/pdf?id=rkgO66VKDS. + + * :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on the output. + + * :attr:`static_enabled` defines the flag for using observer's static estimation for + scale and zero point. + + * :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point. + """ + + def __init__( + self, + observer, + quant_min=0, + quant_max=255, + scale=1.0, + zero_point=0.0, + channel_len=-1, + use_grad_scaling=False, + **observer_kwargs, + ): + super().__init__() + assert quant_min < quant_max, "quant_min must be strictly less than quant_max." + self.quant_min = quant_min + self.quant_max = quant_max + # also pass quant_min and quant_max to observer + observer_kwargs["quant_min"] = quant_min + observer_kwargs["quant_max"] = quant_max + self.use_grad_scaling = use_grad_scaling + if channel_len == -1: + self.scale = Parameter(torch.tensor([scale])) + self.zero_point = Parameter(torch.tensor([zero_point])) + else: + assert isinstance(channel_len, int) and channel_len > 0, ( + "Channel size must be a positive integer." + ) + self.scale = Parameter(torch.tensor([scale] * channel_len)) + self.zero_point = Parameter(torch.tensor([zero_point] * channel_len)) + + self.activation_post_process = observer(**observer_kwargs) + assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, ( + "quant_min out of bound" + ) + assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, ( + "quant_max out of bound" + ) + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = ( + self.activation_post_process.ch_axis + if hasattr(self.activation_post_process, "ch_axis") + else -1 + ) + self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8)) + self.register_buffer("static_enabled", torch.tensor([1], dtype=torch.uint8)) + self.register_buffer("learning_enabled", torch.tensor([0], dtype=torch.uint8)) + + bitrange = torch.tensor(quant_max - quant_min + 1).double() + self.bitwidth = int(torch.log2(bitrange).item()) + self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps])) + + @torch.jit.export + def enable_param_learning(self): + r"""Enable parameter learning over static observer estimates. + + Enables learning of quantization parameters and + disables static observer estimates. Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=True).toggle_fake_quant( + enabled=True + ).toggle_observer_update(enabled=False) + return self + + @torch.jit.export + def enable_static_estimate(self): + """Enable static estimates of quantization parameters. + + Enables static observer estimates and disables learning of + quantization parameters. Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=False).toggle_fake_quant( + enabled=True + ).toggle_observer_update(enabled=True) + + @torch.jit.export + def enable_static_observation(self): + """Enable accumulation of data without updating quantization parameters. + + Enables static observer accumulating data from input but doesn't + update the quantization parameters. Forward path returns the original X. + """ + self.toggle_qparam_learning(enabled=False).toggle_fake_quant( + enabled=False + ).toggle_observer_update(enabled=True) + + @torch.jit.export + def toggle_observer_update(self, enabled=True): + self.static_enabled[0] = int(enabled) # type: ignore[operator] + return self + + @torch.jit.export + def enable_observer(self, enabled=True): + self.toggle_observer_update(enabled) + + @torch.jit.export + def toggle_qparam_learning(self, enabled=True): + self.learning_enabled[0] = int(enabled) # type: ignore[operator] + self.scale.requires_grad = enabled + self.zero_point.requires_grad = enabled + return self + + @torch.jit.export + def toggle_fake_quant(self, enabled=True): + self.fake_quant_enabled[0] = int(enabled) + return self + + @torch.jit.export + def observe_quant_params(self): + print(f"_LearnableFakeQuantize Scale: {self.scale.detach()}") + print(f"_LearnableFakeQuantize Zero Point: {self.zero_point.detach()}") + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator] + scale = self.scale.detach() + zero_point = ( + self.zero_point.detach() + .round() + .clamp(self.quant_min, self.quant_max) + .long() + ) + return scale, zero_point + + def forward(self, X): + if self.static_enabled[0] == 1: # type: ignore[index] + self.activation_post_process(X.detach()) + _scale, _zero_point = self.activation_post_process.calculate_qparams() + _scale = _scale.to(self.scale.device) + _zero_point = _zero_point.to(self.zero_point.device) + self.scale.data.copy_(_scale) + self.zero_point.data.copy_(_zero_point) + else: + self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator] + + if self.fake_quant_enabled[0] == 1: + if self.qscheme in ( + torch.per_channel_symmetric, + torch.per_tensor_symmetric, + ): + self.zero_point.data.zero_() + + if self.use_grad_scaling: + grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5 + else: + grad_factor = 1.0 + if self.qscheme in (torch.per_channel_symmetric, torch.per_channel_affine): + X = torch._fake_quantize_learnable_per_channel_affine( + X, + self.scale, + self.zero_point, + self.ch_axis, + self.quant_min, + self.quant_max, + grad_factor, + ) + else: + X = torch._fake_quantize_learnable_per_tensor_affine( + X, + self.scale, + self.zero_point, + self.quant_min, + self.quant_max, + grad_factor, + ) + + return X diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fake_quantize.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fake_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..c17008adcf6518299e1568f39cd90086a1519b3f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fake_quantize.py @@ -0,0 +1,650 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +"""Implements modules used to perform fake quantization.""" + +import re +from abc import ABC, abstractmethod +from typing import Any + +import torch +from torch.ao.quantization.observer import ( + _with_args, + default_fixed_qparams_range_0to1_observer, + default_fixed_qparams_range_neg1to1_observer, + FixedQParamsObserver, + HistogramObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, +) +from torch.nn import Module + + +__all__ = [ + "FakeQuantizeBase", + "FakeQuantize", + "FixedQParamsFakeQuantize", + "FusedMovingAvgObsFakeQuantize", + "disable_fake_quant", + "disable_observer", + "enable_fake_quant", + "enable_observer", + "default_fake_quant", + "default_weight_fake_quant", + "default_dynamic_fake_quant", + "default_fixed_qparams_range_neg1to1_fake_quant", + "default_fixed_qparams_range_0to1_fake_quant", + "default_symmetric_fixed_qparams_fake_quant", + "default_affine_fixed_qparams_fake_quant", + "default_per_channel_weight_fake_quant", + "default_embedding_fake_quant", + "default_embedding_fake_quant_4bit", + "default_histogram_fake_quant", + "default_fused_act_fake_quant", + "default_fused_wt_fake_quant", + "default_fused_per_channel_wt_fake_quant", + "fused_wt_fake_quant_range_neg_127_to_127", + "fused_per_channel_wt_fake_quant_range_neg_127_to_127", +] + + +def _is_per_channel(qscheme: "torch.qscheme") -> bool: + return qscheme in [ + torch.per_channel_symmetric, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + ] + + +def _is_per_tensor(qscheme: "torch.qscheme") -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine] + + +def _is_symmetric_quant(qscheme: "torch.qscheme") -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric] + + +def _is_float_qparams(qscheme: "torch.qscheme") -> bool: + return qscheme in [ + torch.per_channel_affine_float_qparams, + ] + + +class FakeQuantizeBase(ABC, Module): + r"""Base fake quantize module. + + Base fake quantize module + Any fake quantize implementation should derive from this class. + + Concrete fake quantize module should follow the same API. In forward, they will update + the statistics of the observed Tensor and fake quantize the input. They should also provide a + `calculate_qparams` function that computes the quantization parameters given + the collected statistics. + + """ + + fake_quant_enabled: torch.Tensor + observer_enabled: torch.Tensor + + def __init__(self) -> None: + """Set fake_quant_enabled and observer_enabled.""" + super().__init__() + # fake_quant_enabled and observer_enabled are buffers to support their + # replication in DDP. Data type is uint8 because NCCL does not support + # bool tensors. + self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8)) + self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.uint8)) + + @abstractmethod + def forward(self, x): + pass + + @abstractmethod + def calculate_qparams(self, **kwargs): + pass + + @torch.jit.export + def enable_fake_quant(self, enabled: bool = True) -> None: + self.fake_quant_enabled[0] = 1 if enabled else 0 + + @torch.jit.export + def disable_fake_quant(self): + self.enable_fake_quant(False) + + @torch.jit.export + def enable_observer(self, enabled: bool = True) -> None: + self.observer_enabled[0] = 1 if enabled else 0 + + @torch.jit.export + def disable_observer(self): + self.enable_observer(False) + + @classmethod + def with_args(cls, **kwargs): + fake_quant_constructor = _with_args(cls, **kwargs) + # need to assign the correct module to fake_quantize + # constructors to satisfy public v private requirements + fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize" + return fake_quant_constructor + + +class FakeQuantize(FakeQuantizeBase): + r"""Simulate the quantize and dequantize operations in training time. + + The output of this module is given by:: + + x_out = ( + clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point + ) * scale + + * :attr:`is_dynamic` indicates whether the fake quantie is a placeholder for dynamic quantization + operators (choose_qparams -> q -> dq) or static quantization operators (q -> dq) + + * :attr:`scale` defines the scale factor used for quantization. + + * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to + + * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that + statistics can still be updated. + + * :attr:`observer_enabled` controls statistics collection on tensors + + * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization, + allowable values are torch.qint8 and torch.quint8. + + Args: + + observer (module): Module for observing statistics on input tensors and calculating scale + and zero-point. + observer_kwargs (optional): Arguments for the observer module + + Attributes: + activation_post_process (Module): User provided module that collects statistics on the input tensor and + provides a method to calculate scale and zero-point. + + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__( + self, + observer=MovingAverageMinMaxObserver, + quant_min=None, + quant_max=None, + is_dynamic=False, + **observer_kwargs, + ): + super().__init__() + # Populate quant_min/quant_max to observer_kwargs if valid + if quant_min is not None and quant_max is not None: + assert quant_min <= quant_max, ( + "quant_min must be less than or equal to quant_max" + ) + dtype = observer_kwargs.get("dtype", torch.quint8) + if hasattr(observer, "p"): + # In case observer is _PartialWrapper, dtype can be stored in + # observer.p.keywords["dtype"] + dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get( + "dtype", dtype + ) + assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound" + assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound" + observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max}) + observer_kwargs["is_dynamic"] = is_dynamic + self.activation_post_process = observer(**observer_kwargs) + # TODO: keeping self.quant_min/max for BC; remove after a couple releases + # Users should use self.activation_post_process.quant_min + self.quant_min = self.activation_post_process.quant_min + self.quant_max = self.activation_post_process.quant_max + self.is_dynamic = self.activation_post_process.is_dynamic + if _is_float_qparams(self.activation_post_process.qscheme): + zero_point_dtype = torch.float + else: + zero_point_dtype = torch.int + self.register_buffer("scale", torch.tensor([1.0], dtype=torch.float)) + self.register_buffer("zero_point", torch.tensor([0], dtype=zero_point_dtype)) + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = ( + self.activation_post_process.ch_axis + if hasattr(self.activation_post_process, "ch_axis") + else -1 + ) + assert _is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme), ( + "Only per channel and per tensor quantization are supported in fake quantize" + + " got qscheme: " + + str(self.qscheme) + ) + self.is_per_channel = _is_per_channel(self.qscheme) + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + return self.activation_post_process.calculate_qparams() + + def forward(self, X): + if self.observer_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = self.calculate_qparams() + _scale, _zero_point = ( + _scale.to(self.scale.device), + _zero_point.to(self.zero_point.device), + ) + if self.scale.shape != _scale.shape: + self.scale.resize_(_scale.shape) + self.zero_point.resize_(_zero_point.shape) + self.scale.copy_(_scale) + self.zero_point.copy_(_zero_point) + + if self.fake_quant_enabled[0] == 1: + if self.is_per_channel: + X = torch.fake_quantize_per_channel_affine( + X, + self.scale, + self.zero_point, + self.ch_axis, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + ) + else: + X = torch.fake_quantize_per_tensor_affine( + X, + self.scale, + self.zero_point, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + ) + return X + + @torch.jit.export + def extra_repr(self): + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, " + f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, " + f"scale={self.scale}, zero_point={self.zero_point}" + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + # We cannot currently register scalar values as buffers, so need to manually + # specify serialization here. + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "scale"] = self.scale + destination[prefix + "zero_point"] = self.zero_point + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + # Removing this function throws an error that the size of the loaded tensor does not match the original size + # i.e., These buffers start out with numel 0 and become numel 1 once they have their first forward pass. + local_state = ["scale", "zero_point"] + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + # Custom handling to allow loading scale and zero_point + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. + if name == "scale": + self.scale.resize_(val.shape) + else: + assert name == "zero_point" + self.zero_point.resize_(val.shape) + # For torchscript module we need to update the attributes here since we do not + # call the `_load_from_state_dict` function defined module.py + if torch.jit.is_scripting(): + if name == "scale": + self.scale.copy_(val) + else: + assert name == "zero_point" + self.zero_point.copy_(val) + elif strict: + missing_keys.append(key) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + +class FixedQParamsFakeQuantize(FakeQuantize): + """Simulate quantize and dequantize in training time. + + Simulate quantize and dequantize with fixed quantization + parameters in training time. Only per tensor quantization + is supported. + """ + + # TODO: rename observer to observer_ctr + def __init__(self, observer): + super().__init__(observer=observer) + assert type(self.activation_post_process) == FixedQParamsObserver, ( + f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}" + ) + self._observer_ctr = observer + self.scale = self.activation_post_process.scale + self.zero_point = self.activation_post_process.zero_point + assert _is_per_tensor(self.qscheme), ( + "Only per tensor quantization is supported" + + " FixedQParamsFakeQuantize module, got qscheme:" + + str(self.qscheme) + ) + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + return self.scale, self.zero_point + + @torch.jit.export + def extra_repr(self): + """Define a string representation of the object's attributes.""" + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"scale={self.scale}, zero_point={self.zero_point}, " + f"dtype={self.dtype}, quant_min={self.activation_post_process.quant_min}, " + f"quant_max={self.activation_post_process.quant_max}, qscheme={self.qscheme}" + ) + + +class FusedMovingAvgObsFakeQuantize(FakeQuantize): + r"""Define a fused module to observe the tensor. + + Fused module that is used to observe the input tensor (compute min/max), compute + scale/zero_point and fake_quantize the tensor. + This module uses calculation similar MovingAverageMinMaxObserver for the inputs, + to compute the min/max values in order to compute the scale/zero_point. + The qscheme input in the observer is used to differentiate between symmetric/affine + quantization scheme. + + The output of this module is given by + x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale + + Similar to :class:`~torch.ao.quantization.FakeQuantize`, and accepts the same attributes as the + base class. + + """ + + def __init__( + self, + observer: Any = MovingAverageMinMaxObserver, + quant_min: int = 0, + quant_max: int = 255, + **observer_kwargs: Any, + ) -> None: + super().__init__(observer, quant_min, quant_max, **observer_kwargs) + assert isinstance( + self.activation_post_process, + (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver), + ), ( + "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver" + ) + self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long)) + self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long)) + self.is_symmetric_quant = _is_symmetric_quant( + self.activation_post_process.qscheme + ) + + @torch.jit.export + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override] + return self.activation_post_process.calculate_qparams() + + @torch.jit.export + def extra_repr(self) -> str: + return ( + f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, " + f"scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}, " + f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, " + f"qscheme={self.qscheme}, reduce_range={self.activation_post_process.reduce_range}" + ) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + return torch.fused_moving_avg_obs_fake_quant( + X, + self.observer_enabled, + self.fake_quant_enabled, + self.activation_post_process.min_val, + self.activation_post_process.max_val, + self.scale, + self.zero_point, + self.activation_post_process.averaging_constant, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + self.ch_axis, + self.is_per_channel, + self.is_symmetric_quant, + ) + + +default_fake_quant = FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, +) +""" +Default fake_quant for activations. +""" + +default_weight_fake_quant = FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + reduce_range=False, +) +""" +Default fake_quant for weights. +Observer is memoryless since averaging_constant is 1. +""" + +default_dynamic_fake_quant = FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + is_dynamic=True, + dtype=torch.quint8, + averaging_constant=1, +) +""" +Default dynamic fake_quant for activations. +""" + +default_fixed_qparams_range_neg1to1_fake_quant = FixedQParamsFakeQuantize.with_args( + observer=default_fixed_qparams_range_neg1to1_observer +) +default_fixed_qparams_range_0to1_fake_quant = FixedQParamsFakeQuantize.with_args( + observer=default_fixed_qparams_range_0to1_observer +) +# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases +default_symmetric_fixed_qparams_fake_quant = ( + default_fixed_qparams_range_neg1to1_fake_quant +) +default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant + +default_per_channel_weight_fake_quant = FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + ch_axis=0, +) +""" +Default fake_quant for per-channel weights. +Observer is memoryless since averaging_constant is 1. +""" +default_embedding_fake_quant = FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + qscheme=torch.per_channel_affine_float_qparams, + dtype=torch.quint8, + quant_min=0, + quant_max=255, + ch_axis=0, + averaging_constant=1, +) +""" +Default fake_quant for embeddings. +Observer is memoryless since averaging_constant is 1. +""" + +default_embedding_fake_quant_4bit = FakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0, + dtype=torch.quint4x2, + averaging_constant=1, +) + +default_histogram_fake_quant = FakeQuantize.with_args( + observer=HistogramObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, +) +""" +Fake_quant for activations using a histogram.. +""" + + +default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, +) + +""" +Fused version of `default_fake_quant`, with improved performance. +""" + + +default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, +) +""" +Fused version of `default_weight_fake_quant`, with improved performance. +""" + +default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, +) +""" +Fused version of `default_per_channel_weight_fake_quant`, with improved performance. +""" + +fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-127, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + eps=2**-12, +) +""" +Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +fused_per_channel_wt_fake_quant_range_neg_127_to_127 = ( + FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-127, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + eps=2**-12, + ) +) + +""" +Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128. +""" + + +def _is_fake_quant_script_module(mod): + """Return true if given mod is an instance of FakeQuantize script module.""" + if isinstance(mod, torch.jit.RecursiveScriptModule): + # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize' + suffix = mod._c.qualified_name.split(".", 1)[1] + name = re.sub(r"\.___torch_mangle_\d+", "", suffix) + return ( + name == "torch.ao.quantization.fake_quantize.FakeQuantize" + or name + == "torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize" + ) + return False + + +def disable_fake_quant(mod): + """Disable fake quantization for the module. + + Disable fake quantization for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.disable_fake_quant) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.disable_fake_quant() + + +def enable_fake_quant(mod): + """Enable fake quantization for the module. + + Enable fake quantization for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.enable_fake_quant) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.enable_fake_quant() + + +def disable_observer(mod): + """Disable observation for this module. + + Disable observation for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.disable_observer) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.disable_observer() + + +def enable_observer(mod): + """Enable observation for this module. + + Enable observation for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.enable_observer) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.enable_observer() diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fuse_modules.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fuse_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d151858c7b8c0c34e995e03839aab89290b66d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fuse_modules.py @@ -0,0 +1,216 @@ +# mypy: allow-untyped-defs +import copy +from typing import Optional + +import torch.nn as nn + +# for backward compatibility +from torch.ao.quantization.fuser_method_mappings import ( # noqa: F401 # noqa: F401 + fuse_conv_bn, + fuse_conv_bn_relu, + get_fuser_method, +) +from torch.nn.utils.parametrize import type_before_parametrizations + + +__all__ = [ + "fuse_known_modules", + "fuse_modules", + "fuse_modules_qat", +] + + +# Generalization of getattr +def _get_module(model, submodule_key): + tokens = submodule_key.split(".") + cur_mod = model + for s in tokens: + cur_mod = getattr(cur_mod, s) + return cur_mod + + +# Generalization of setattr +def _set_module(model, submodule_key, module): + tokens = submodule_key.split(".") + sub_tokens = tokens[:-1] + cur_mod = model + for s in sub_tokens: + cur_mod = getattr(cur_mod, s) + + setattr(cur_mod, tokens[-1], module) + + +def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None): + r"""Return a list of known fuse modules. + + Returns a list of modules that fuses the operations specified + in the input module list. + + Fuses only the following sequence of modules: + conv, bn + conv, bn, relu + conv, relu + linear, bn + linear, relu + For these sequences, the first element in the output module list performs + the fused operation. The rest of the elements are set to nn.Identity() + """ + types = tuple(type_before_parametrizations(m) for m in mod_list) + fuser_method = get_fuser_method(types, additional_fuser_method_mapping) + if fuser_method is None: + raise NotImplementedError(f"Cannot fuse modules: {types}") + new_mod: list[Optional[nn.Module]] = [None] * len(mod_list) + fused = fuser_method(is_qat, *mod_list) + # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion + # Move pre forward hooks of the base module to resulting fused module + for pre_hook_fn in mod_list[0]._forward_pre_hooks.values(): + fused.register_forward_pre_hook(pre_hook_fn) + mod_list[0]._forward_pre_hooks.clear() + # Move post forward hooks of the last module to resulting fused module + for hook_fn in mod_list[-1]._forward_hooks.values(): + fused.register_forward_hook(hook_fn) + mod_list[-1]._forward_hooks.clear() + new_mod[0] = fused + + for i in range(1, len(mod_list)): + identity = nn.Identity() + identity.training = mod_list[0].training + new_mod[i] = identity + + return new_mod + + +def _fuse_modules_helper( + model, + modules_to_fuse, + is_qat, + fuser_func=fuse_known_modules, + fuse_custom_config_dict=None, +): + if fuse_custom_config_dict is None: + fuse_custom_config_dict = {} + additional_fuser_method_mapping = fuse_custom_config_dict.get( + "additional_fuser_method_mapping", {} + ) + mod_list = [_get_module(model, item) for item in modules_to_fuse] + + # Fuse list of modules + new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping) + + # Replace original module list with fused module list + for i, item in enumerate(modules_to_fuse): + _set_module(model, item, new_mod_list[i]) + + +def _fuse_modules( + model, + modules_to_fuse, + is_qat, + inplace=False, + fuser_func=fuse_known_modules, + fuse_custom_config_dict=None, +): + if not inplace: + model = copy.deepcopy(model) + + if all(isinstance(module_element, str) for module_element in modules_to_fuse): + # Handle case of modules_to_fuse being a list + _fuse_modules_helper( + model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict + ) + else: + # Handle case of modules_to_fuse being a list of lists + for module_list in modules_to_fuse: + _fuse_modules_helper( + model, module_list, is_qat, fuser_func, fuse_custom_config_dict + ) + return model + + +def fuse_modules( + model, + modules_to_fuse, + inplace=False, + fuser_func=fuse_known_modules, + fuse_custom_config_dict=None, +): + r"""Fuse a list of modules into a single module. + + Fuses only the following sequence of modules: + conv, bn + conv, bn, relu + conv, relu + linear, relu + bn, relu + All other sequences are left unchanged. + For these sequences, replaces the first item in the list + with the fused module, replacing the rest of the modules + with identity. + + Args: + model: Model containing the modules to be fused + modules_to_fuse: list of list of module names to fuse. Can also be a list + of strings if there is only a single list of modules to fuse. + inplace: bool specifying if fusion happens in place on the model, by default + a new model is returned + fuser_func: Function that takes in a list of modules and outputs a list of fused modules + of the same length. For example, + fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()] + Defaults to torch.ao.quantization.fuse_known_modules + `fuse_custom_config_dict`: custom configuration for fusion + + .. code-block:: python + + # Example of fuse_custom_config_dict + fuse_custom_config_dict = { + # Additional fuser_method mapping + "additional_fuser_method_mapping": { + (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn + }, + } + + Returns: + model with fused modules. A new copy is created if inplace=True. + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = M().eval() + >>> # m is a module containing the sub-modules below + >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']] + >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) + >>> output = fused_m(input) + + >>> m = M().eval() + >>> # Alternately provide a single list of modules to fuse + >>> modules_to_fuse = ['conv1', 'bn1', 'relu1'] + >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) + >>> output = fused_m(input) + + """ + return _fuse_modules( + model, + modules_to_fuse, + is_qat=False, + inplace=inplace, + fuser_func=fuser_func, + fuse_custom_config_dict=fuse_custom_config_dict, + ) + + +def fuse_modules_qat( + model, + modules_to_fuse, + inplace=False, + fuser_func=fuse_known_modules, + fuse_custom_config_dict=None, +): + """QAT version for `fuse_modules`.""" + return _fuse_modules( + model, + modules_to_fuse, + is_qat=True, + inplace=inplace, + fuser_func=fuser_func, + fuse_custom_config_dict=fuse_custom_config_dict, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/fuser_method_mappings.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fuser_method_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..260bbee37bd2bd1c8b33a175842bb4ebc3251ab4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/fuser_method_mappings.py @@ -0,0 +1,292 @@ +# mypy: allow-untyped-defs +import itertools +from typing import Any, Callable, Optional, Union + +import torch.ao.nn.intrinsic as nni +import torch.nn as nn +from torch.ao.quantization.utils import get_combined_dict, MatchAllNode, Pattern + + +__all__ = [ + "fuse_conv_bn", + "fuse_conv_bn_relu", + "fuse_linear_bn", + "fuse_convtranspose_bn", + "get_fuser_method", + "get_fuser_method_new", +] + + +def fuse_conv_bn(is_qat, conv, bn): + r"""Return the fused the conv and bn modules. + Given the conv and bn modules, fuses them and returns the fused module + + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + conv: Module instance of type conv2d/conv3d + bn: Spatial BN instance that needs to be fused with the conv + + Examples:: + + >>> m1 = nn.Conv2d(10, 20, 3) + >>> b1 = nn.BatchNorm2d(20) + >>> # xdoctest: +SKIP + >>> m2 = fuse_conv_bn(m1, b1) + """ + assert conv.training == bn.training, ( + "Conv and BN both must be in the same mode (train or eval)." + ) + + fused_module_class_map = { + nn.Conv1d: nni.ConvBn1d, + nn.Conv2d: nni.ConvBn2d, + nn.Conv3d: nni.ConvBn3d, + } + + if is_qat: + assert bn.num_features == conv.out_channels, ( + "Output channel of Conv2d must match num_features of BatchNorm2d" + ) + assert bn.affine, "Only support fusing BatchNorm2d with affine set to True" + assert bn.track_running_stats, ( + "Only support fusing BatchNorm2d with tracking_running_stats set to True" + ) + fused_module_class = fused_module_class_map.get((type(conv)), None) + if fused_module_class is not None: + return fused_module_class(conv, bn) + else: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn)}") + else: + return nn.utils.fuse_conv_bn_eval(conv, bn) + + +def fuse_conv_bn_relu(is_qat, conv, bn, relu): + r"""Return the fused conv and bv modules. + + Given the conv and bn modules, fuses them and returns the fused module + + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + conv: Module instance of type conv2d/conv3d + bn: Spatial BN instance that needs to be fused with the conv + + Examples:: + + >>> m1 = nn.Conv2d(10, 20, 3) + >>> b1 = nn.BatchNorm2d(20) + >>> r1 = nn.ReLU(inplace=False) + >>> # xdoctest: +SKIP + >>> m2 = fuse_conv_bn_relu(m1, b1, r1) + """ + assert conv.training == bn.training == relu.training, ( + "Conv and BN both must be in the same mode (train or eval)." + ) + fused_module: Optional[type[nn.Sequential]] = None + if is_qat: + map_to_fused_module_train = { + nn.Conv1d: nni.ConvBnReLU1d, + nn.Conv2d: nni.ConvBnReLU2d, + nn.Conv3d: nni.ConvBnReLU3d, + } + assert bn.num_features == conv.out_channels, ( + "Output channel of Conv must match num_features of BatchNorm" + ) + assert bn.affine, "Only support fusing BatchNorm with affine set to True" + assert bn.track_running_stats, ( + "Only support fusing BatchNorm with tracking_running_stats set to True" + ) + fused_module = map_to_fused_module_train.get(type(conv), None) + if fused_module is not None: + return fused_module(conv, bn, relu) + else: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, relu)}") + else: + map_to_fused_module_eval = { + nn.Conv1d: nni.ConvReLU1d, + nn.Conv2d: nni.ConvReLU2d, + nn.Conv3d: nni.ConvReLU3d, + } + fused_module = map_to_fused_module_eval.get(type(conv), None) + if fused_module is not None: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return fused_module(fused_conv, relu) + else: + raise NotImplementedError(f"Cannot fuse eval modules: {(conv, bn, relu)}") + + +def fuse_linear_bn(is_qat, linear, bn): + r"""Return the fused linear and bn modules. + Given the linear and bn modules, fuses them and returns the fused module + + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + linear: Module instance of type Linear + bn: BatchNorm1d instance that needs to be fused with the linear layer + + Examples:: + + >>> m1 = nn.Linear(20, 10) + >>> b1 = nn.BatchNorm1d(10) + >>> # xdoctest: +SKIP + >>> m2 = fuse_linear_bn(m1, b1) + """ + assert linear.training == bn.training, ( + "Linear and BN both must be in the same mode (train or eval)." + ) + + if is_qat: + assert bn.num_features == linear.out_features, ( + "Output features of Linear must match num_features of BatchNorm1d" + ) + assert bn.affine, "Only support fusing BatchNorm1d with affine set to True" + assert bn.track_running_stats, ( + "Only support fusing BatchNorm1d with tracking_running_stats set to True" + ) + return nni.LinearBn1d(linear, bn) + else: + return nn.utils.fusion.fuse_linear_bn_eval(linear, bn) + + +def fuse_convtranspose_bn(is_qat, convt, bn): + r"""Return the fused ConvTranspose and bn modules. + Given ConvTranspose and bn modules, fuses them and returns the fused module + + Args: + convt: Module instance of type ConvTransposeNd + bn: BatchNormNd instance that needs to be fused with the linear layer. + batch norm N should match the ConvTranspose N + + Examples:: + + >>> m1 = nn.ConvTranspose2d(10, 20, 3) + >>> b1 = nn.BatchNorm2d(20) + >>> # xdoctest: +SKIP + >>> m2 = fuse_convtranspose_bn(m1, b1) + """ + assert convt.training == bn.training, ( + "ConvTranspose and BN both must be in the same mode (train or eval)." + ) + + if is_qat: + raise Exception( # noqa: TRY002 + "Fusing ConvTranspose+BatchNorm not yet supported in QAT." + ) + else: + return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True) + + +def _sequential_wrapper2(sequential): + """Return a sequential wrapped that for is_qat and two modules. + Given a sequential class for two modules, return a function that takes + is_qat, and then two modules as argument, that ignores the is_qat flag + and always returns the sequential that combines the two input modules + """ + + def fuser_method(is_qat, m1, m2): + return sequential(m1, m2) + + return fuser_method + + +_DEFAULT_OP_LIST_TO_FUSER_METHOD: dict[tuple, Union[nn.Sequential, Callable]] = { + (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn, + (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, + (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn, + (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu, + (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn, + (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu, + (nn.Conv1d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU1d), + (nn.Conv2d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d), + (nn.Conv3d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU3d), + (nn.Linear, nn.BatchNorm1d): fuse_linear_bn, + (nn.Linear, nn.ReLU): _sequential_wrapper2(nni.LinearReLU), + (nn.BatchNorm2d, nn.ReLU): _sequential_wrapper2(nni.BNReLU2d), + (nn.BatchNorm3d, nn.ReLU): _sequential_wrapper2(nni.BNReLU3d), + (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn, + (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn, + (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn, +} + + +def get_fuser_method(op_list, additional_fuser_method_mapping=None): + """Get fuser method for the given list of module types. + + Get fuser method for the given list of module types, + return None if fuser method does not exist + """ + if additional_fuser_method_mapping is None: + additional_fuser_method_mapping = {} + all_mappings = get_combined_dict( + _DEFAULT_OP_LIST_TO_FUSER_METHOD, additional_fuser_method_mapping + ) + fuser_method = all_mappings.get(op_list, None) + assert fuser_method is not None, f"did not find fuser method for: {op_list} " + return fuser_method + + +def _reverse2(f): + def reversed(is_qat, x, y): + return f(is_qat, y, x) + + return reversed + + +def _reverse3(f): + def reversed(is_qat, x, w): + y, z = w + return f(is_qat, z, y, x) + + return reversed + + +def _get_valid_patterns(op_pattern): + """Return a list of valid patterns generated from the op_pattern. + + Returns a list of valid patterns generated from the op_pattern, + since MatchAllNode can match all types of nodes, + e.g. pattern (torch.nn.Conv2d, torch.add) should also be able to match keys like + (MatchAllNode, torch.add) and (torch.nn.Conv2d, MatchAllNode) + + Example Input: + (torch.add, (torch.nn.ReLU, torch.nn.Conv2d)) + + Example Output: + [(torch.add, (torch.nn.ReLU, torch.nn.Conv2d)), + (torch.add, (torch.nn.ReLU, MatchAllNode)), + (torch.add, (MatchAllNode, torch.nn.Conv2d)), + (torch.add, (MatchAllNode, MatchAllNode)), + (MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d)), + (MatchAllNode, (torch.nn.ReLU, MatchAllNode)), + (MatchAllNode, (MatchAllNode, torch.nn.Conv2d)), + (MatchAllNode, (MatchAllNode, MatchAllNode)), + ] + """ + result: list[Any] + if isinstance(op_pattern, (tuple, list)): + sub_combs = [_get_valid_patterns(sub_pattern) for sub_pattern in op_pattern] + result = list(itertools.product(*sub_combs)) + else: + result = [op_pattern, MatchAllNode] + return result + + +def get_fuser_method_new( + op_pattern: Pattern, + fuser_method_mapping: dict[Pattern, Union[nn.Sequential, Callable]], +): + """Get fuser method. + + This will be made default after we deprecate the get_fuser_method + Would like to implement this first and have a separate PR for deprecation + """ + op_patterns = _get_valid_patterns(op_pattern) + fuser_method = None + for op_pattern in op_patterns: + fuser_method = fuser_method_mapping.get(op_pattern, None) + if fuser_method is not None: + break + assert fuser_method is not None, f"did not find fuser method for: {op_pattern} " + return fuser_method diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/observer.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/observer.py new file mode 100644 index 0000000000000000000000000000000000000000..4fd3145f6bdec446cdae55ea11e86ab53455abb6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/observer.py @@ -0,0 +1,2131 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# temporarily skip RUF for this file for now, we can re-enable +# after move the affine quantization related things to torchao +# noqa: RUF +""" +This module implements observers which are used to collect statistics about +the values observed during calibration (PTQ) or training (QAT). +""" + +import operator +import re +import warnings +from abc import ABCMeta, abstractmethod +from collections import OrderedDict +from functools import partial +from typing import Any, Optional + +import torch +import torch.nn as nn +from torch.ao.quantization.utils import ( + calculate_qmin_qmax, + check_min_max_valid, + is_per_channel, + is_per_tensor, + validate_qmin_qmax, +) +from torch.fx import Node + + +__all__ = [ + "default_affine_fixed_qparams_observer", + "default_debug_observer", + "default_dynamic_quant_observer", + "default_fixed_qparams_range_0to1_observer", + "default_fixed_qparams_range_neg1to1_observer", + "default_float_qparams_observer", + "default_float_qparams_observer_4bit", + "default_histogram_observer", + "default_observer", + "default_per_channel_weight_observer", + "default_placeholder_observer", + "default_reuse_input_observer", + "default_symmetric_fixed_qparams_observer", + "default_weight_observer", + "get_observer_state_dict", + "load_observer_state_dict", + "per_channel_weight_observer_range_neg_127_to_127", + "weight_observer_range_neg_127_to_127", + "FixedQParamsObserver", + "HistogramObserver", + "MinMaxObserver", + "MovingAverageMinMaxObserver", + "MovingAveragePerChannelMinMaxObserver", + "NoopObserver", + "ObserverBase", + "PerChannelMinMaxObserver", + "PlaceholderObserver", + "RecordingObserver", + "ReuseInputObserver", + "UniformQuantizationObserverBase", + "AffineQuantizedObserverBase", + "Granularity", + "MappingType", + "PerAxis", + "PerBlock", + "PerGroup", + "PerRow", + "PerTensor", + "PerToken", + "TorchAODType", + "ZeroPointDomain", + "get_block_size", +] + + +class _PartialWrapper: + def __init__(self, p): + self.p = p + self.callable_args = {} + + def __call__(self, *args, **keywords): + # call each arg in callable_args and add them partial, then run with keywords + # skip if arg_name in keywords so its possible to overwrite + for arg_name in self.callable_args: + if arg_name not in keywords: + keywords = {**keywords, arg_name: self.callable_args[arg_name]()} + return self.p(*args, **keywords) + + def __repr__(self): + return self.p.__repr__() + self.callable_args.__repr__() + + def with_args(self, **kwargs): + return _with_args(self, **kwargs) + + def with_callable_args(self, **kwargs): + result = _PartialWrapper(p=self.p) + result.callable_args = {**self.callable_args, **kwargs} + return result + + +def _with_args(cls_or_self, **kwargs): + r"""Wrapper that allows creation of class factories. + + This can be useful when there is a need to create classes with the same + constructor arguments, but different instances. Can be used in conjunction with + _callable_args + + Example:: + + >>> # xdoctest: +SKIP("Undefined vars") + >>> Foo.with_args = classmethod(_with_args) + >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) + >>> foo_instance1 = foo_builder() + >>> foo_instance2 = foo_builder() + >>> id(foo_instance1) == id(foo_instance2) + False + """ + r = _PartialWrapper(partial(cls_or_self, **kwargs)) + return r + + +def _with_callable_args(cls_or_self, **kwargs): + r"""Wrapper that allows creation of class factories args that need to be + called at construction time. + + This can be useful when there is a need to create classes with the same + constructor arguments, but different instances and those arguments should only + be calculated at construction time. Can be used in conjunction with _with_args + + Example:: + + >>> # xdoctest: +SKIP("Undefined vars") + >>> Foo.with_callable_args = classmethod(_with_callable_args) + >>> Foo.with_args = classmethod(_with_args) + >>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan") + >>> foo_instance1 = foo_builder() + >>> # wait 50 + >>> foo_instance2 = foo_builder() + >>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time) + False + """ + r = _PartialWrapper(partial(cls_or_self)) + return r.with_callable_args(**kwargs) + + +ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: + + +class ObserverBase(ABC, nn.Module): + r"""Base observer Module. + Any observer implementation should derive from this class. + + Concrete observers should follow the same API. In forward, they will update + the statistics of the observed Tensor. And they should provide a + `calculate_qparams` function that computes the quantization parameters given + the collected statistics. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + is_dynamic: indicator for whether the observer is a placeholder for dynamic quantization + or static quantization + """ + + def __init__(self, dtype, is_dynamic: bool = False): + super().__init__() + self.dtype = dtype + self.is_dynamic = is_dynamic + + @abstractmethod + def forward(self, x): + pass + + @abstractmethod + def calculate_qparams(self, **kwargs): + pass + + with_args = classmethod(_with_args) + with_callable_args = classmethod(_with_callable_args) + + +class UniformQuantizationObserverBase(ObserverBase): + r"""Common base for all observers using uniform quantization to calculate + scale and zero_point. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used. + reduce_range: Reduces the range of the quantized data type by 1 bit. + This is sometimes required to avoid instruction overflow. + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + .. warning:: + + :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. + or `torch.int8` or `torch.uint8` + + .. warning:: + + :attr:`qscheme` can only take one of the following options: + + - ``torch.per_tensor_affine`` + - ``torch.per_tensor_symmetric`` + - ``torch.per_channel_affine`` + - ``torch.per_channel_symmetric`` + """ + + # Note: the version is shared by all observer types + # + # Version 1/None + # self + # + # Version 2 (base class only, does not include child class buffers) + # self + # |--- eps : Tensor + # + # Version 3 + # for HistogramObserver only, changed the shape of uninitialized + # min_val and max_val buffers from torch.Size([0]) to torch.Size([]) + # for PerChannelObservers, changed the name of the buffers from min_vals + # to min_val and from max_vals to max_val. + _version = 3 + + eps: torch.Tensor + + def __init__( + self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs) + self.qscheme = qscheme + if reduce_range: + warnings.warn( + "Please use quant_min and quant_max to specify the range for observers. \ + reduce_range will be deprecated in a future release of PyTorch." + ) + self.reduce_range = reduce_range + self.register_buffer("eps", torch.tensor([eps], **factory_kwargs)) + assert self.qscheme in ( + torch.per_tensor_affine, + torch.per_tensor_symmetric, + torch.per_channel_affine, + torch.per_channel_symmetric, + torch.per_channel_affine_float_qparams, + ), ( + "Default Observer only works for per_tensor_affine, \ + per_tensor_symmetric, per_channel_affine, \ + per_channel_symmetric and per_channel_float_qparams quantization scheme" + ) + + _ALLOWED_DTYPES = ( + torch.qint8, + torch.quint8, + torch.quint4x2, + torch.qint32, + torch.int8, + torch.uint8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.uint16, + ) + + assert self.dtype in _ALLOWED_DTYPES, ( + f"Default Observer only works for {_ALLOWED_DTYPES} data type" + ) + self.has_customized_qrange = (quant_min is not None) and (quant_max is not None) + if self.has_customized_qrange: + validate_qmin_qmax(quant_min, quant_max) + self.quant_min, self.quant_max = calculate_qmin_qmax( + quant_min, + quant_max, + self.has_customized_qrange, + self.dtype, + self.reduce_range, + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version == 1: + # eps was moved to a buffer in version 2 + eps = torch.tensor([torch.finfo(torch.float32).eps]) + state_dict[prefix + "eps"] = eps + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None: + r"""Validates that the user-specified quantization range is properly initialized + and within the given bound supported by the observer dtype. + + To accommodate lower-bit quantization with respect to the existing torch.qint8 and + torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing + in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax + values are used to calculate static estimates of the scale and zero point for aggressive lower-bit + fake quantization. These estimates are compared against parameters learned through backpropagation. + The related literatures for scale and zero point via backpropagation are as follows: + + Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS + Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf + """ + # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted + # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer. + assert quant_min <= 0 <= quant_max, ( + "Used-specified quantization range must include 0." + ) + assert quant_min < quant_max, ( + "qmin must be strictly less than qmax for user-specified quantization range." + ) + + @torch.jit.export + def _calculate_qparams( + self, min_val: torch.Tensor, max_val: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + r"""Calculates the quantization parameters, given min and max + value tensors. Works for both per tensor and per channel cases + + Args: + min_val: Minimum values per channel + max_val: Maximum values per channel + + Returns: + scales: Scales tensor of shape (#channels,) + zero_points: Zero points tensor of shape (#channels,) + """ + # Functionally equivalent to 'determine_qparams' in utils.py. Observers must be torchscriptable however and qscheme + # as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer + # to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code + # seems unlikey to change (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. + # TODO(jakeszwe, jerryzh168) + if not check_min_max_valid(min_val, max_val): + return torch.tensor([1.0], device=min_val.device.type), torch.tensor( + [0], device=min_val.device.type + ) + + quant_min, quant_max = self.quant_min, self.quant_max + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + + device = min_val_neg.device + scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + if ( + self.qscheme == torch.per_tensor_symmetric + or self.qscheme == torch.per_channel_symmetric + ): + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = torch.max(scale, self.eps) + if self.dtype in [torch.quint8, torch.uint8]: + if self.has_customized_qrange: + # When customized quantization range is used, down-rounded midpoint of the range is chosen. + zero_point = zero_point.new_full( + zero_point.size(), (quant_min + quant_max) // 2 + ) + else: + zero_point = zero_point.new_full(zero_point.size(), 128) + elif self.dtype in [torch.uint16]: + zero_point = zero_point.new_full(zero_point.size(), 2**15) + elif self.qscheme == torch.per_channel_affine_float_qparams: + scale = (max_val - min_val) / float(quant_max - quant_min) + scale = torch.where(scale > self.eps, scale, torch.ones_like(scale)) + # We use the quantize function + # xq = Round(Xf * inv_scale + zero_point), + # setting zero_point to (-1 * min *inv_scale) we get + # Xq = Round((Xf - min) * inv_scale) + zero_point = -1 * min_val / scale + else: + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.max(scale, self.eps) + zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + + # For scalar values, cast them to Tensors of size 1 to keep the shape + # consistent with default values in FakeQuantize. + if len(scale.shape) == 0: + # TODO: switch to scale.item() after adding JIT support + scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) + if len(zero_point.shape) == 0: + # TODO: switch to zero_point.item() after adding JIT support + zero_point = torch.tensor( + [int(zero_point)], dtype=zero_point.dtype, device=device + ) + if self.qscheme == torch.per_channel_affine_float_qparams: + zero_point = torch.tensor( + [float(zero_point)], dtype=zero_point.dtype, device=device + ) + + return scale, zero_point + + @torch.jit.export + def reset_min_max_vals(self): + raise NotImplementedError("Cannot reset min/max values in the given observer.") + + +# Originally, this class was called `_ObserverBase`. Keeping the old name around +# for backwards compatibility. +# TODO(after v1.13): delete this +_ObserverBase = UniformQuantizationObserverBase + + +class MinMaxObserver(UniformQuantizationObserverBase): + r"""Observer module for computing the quantization parameters based on the + running min and max values. + + This observer uses the tensor min/max statistics to compute the quantization + parameters. The module records the running minimum and maximum of incoming + tensors, and uses this statistic to compute the quantization parameters. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`, + scale :math:`s` and zero point :math:`z` are computed as: + + The running minimum/maximum :math:`x_\text{min/max}` is computed as: + + .. math:: + + \begin{array}{ll} + x_\text{min} &= \begin{cases} + \min(X) & \text{if~}x_\text{min} = \text{None} \\ + \min\left(x_\text{min}, \min(X)\right) & \text{otherwise} + \end{cases}\\ + x_\text{max} &= \begin{cases} + \max(X) & \text{if~}x_\text{max} = \text{None} \\ + \max\left(x_\text{max}, \max(X)\right) & \text{otherwise} + \end{cases}\\ + \end{array} + + where :math:`X` is the observed tensor. + + The scale :math:`s` and zero point :math:`z` are then computed as: + + .. math:: + + \begin{aligned} + \text{if Symmetric:}&\\ + &s = 2 \max(|x_\text{min}|, x_\text{max}) / + \left( Q_\text{max} - Q_\text{min} \right) \\ + &z = \begin{cases} + 0 & \text{if dtype is qint8} \\ + 128 & \text{otherwise} + \end{cases}\\ + \text{Otherwise:}&\\ + &s = \left( x_\text{max} - x_\text{min} \right ) / + \left( Q_\text{max} - Q_\text{min} \right ) \\ + &z = Q_\text{min} - \text{round}(x_\text{min} / s) + \end{aligned} + + where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and + maximum of the quantized data type. + + .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. + + .. note:: If the running minimum equals to the running maximum, the scale + and zero_point are set to 1.0 and 0. + """ + + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + "MinMaxObserver's qscheme only support torch.per_tensor_symmetric \ + and torch.per_tensor_affine." + ) + # TODO: MinMaxObserver by itself doesn't support dynamic quantization, but + # if it's inherited by MovingAverageObserver, and averaging_constant is 1, it + # supports dynamic quantization, we may need to better error checking here + + # For x86 quantized kernels, we need to ensure that the vpmaddubsw + # instruction does not overflow. We allow for a reduce_range argument to + # observers that reduces the quantized range to (0,127) or (-64, 63). + # For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp + # This is not an optimal choice for non x86 backends as it loses a bit + # of precision for activations. + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + if ( + self.qscheme == torch.per_tensor_symmetric + and self.reduce_range + and self.dtype == torch.quint8 + ): + raise NotImplementedError( + "Cannot reduce range for symmetric \ + quantization for quint8" + ) + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = torch.min(min_val_cur, self.min_val) + max_val = torch.max(max_val_cur, self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + r"""Calculates the quantization parameters.""" + return self._calculate_qparams(self.min_val, self.max_val) + + @torch.jit.export + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + @torch.jit.export + def reset_min_max_vals(self): + """Resets the min/max values.""" + self.min_val.copy_(torch.tensor(float("inf"))) + self.max_val.copy_(torch.tensor(float("-inf"))) + + +class MovingAverageMinMaxObserver(MinMaxObserver): + r"""Observer module for computing the quantization parameters based on the + moving average of the min and max values. + + This observer computes the quantization parameters based on the moving + averages of minimums and maximums of the incoming tensors. The module + records the average minimum and maximum of incoming tensors, and uses this + statistic to compute the quantization parameters. + + Args: + averaging_constant: Averaging constant for min/max. + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The moving average min/max is computed as follows + + .. math:: + + \begin{array}{ll} + x_\text{min} = \begin{cases} + \min(X) & \text{if~}x_\text{min} = \text{None} \\ + (1 - c) x_\text{min} + c \min(X) & \text{otherwise} + \end{cases}\\ + x_\text{max} = \begin{cases} + \max(X) & \text{if~}x_\text{max} = \text{None} \\ + (1 - c) x_\text{max} + c \max(X) & \text{otherwise} + \end{cases}\\ + \end{array} + + where :math:`x_\text{min/max}` is the running average min/max, :math:`X` is + is the incoming tensor, and :math:`c` is the ``averaging_constant``. + + The scale and zero point are then computed as in + :class:`~torch.ao.quantization.observer.MinMaxObserver`. + + .. note:: Only works with ``torch.per_tensor_affine`` quantization scheme. + + .. note:: If the running minimum equals to the running maximum, the scale + and zero_point are set to 1.0 and 0. + """ + + def __init__( + self, + averaging_constant=0.01, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + f"MovingAverageMinMaxObserver's qscheme only support \ + torch.per_tensor_symmetric and torch.per_tensor_affine. \ + but got: {qscheme}" + ) + self.averaging_constant = averaging_constant + if is_dynamic and self.averaging_constant != 1: + raise NotImplementedError( + "MovingAverageMinMaxObserver doesn't support dynamic quantization for " + f"averaging constant of {self.averaging_constant}" + ) + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val = self.min_val + max_val = self.max_val + if min_val == float("inf") and max_val == float("-inf"): + min_val, max_val = torch.aminmax(x) + else: + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = min_val + self.averaging_constant * (min_val_cur - min_val) + max_val = max_val + self.averaging_constant * (max_val_cur - max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + +class PerChannelMinMaxObserver(UniformQuantizationObserverBase): + r"""Observer module for computing the quantization parameters based on the + running per channel min and max values. + + This observer uses the tensor min/max statistics to compute the per channel + quantization parameters. The module records the running minimum and maximum + of incoming tensors, and uses this statistic to compute the quantization + parameters. + + Args: + ch_axis: Channel axis + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The quantization parameters are computed the same way as in + :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference + that the running min/max values are stored per channel. + Scales and zero points are thus computed per channel as well. + + .. note:: If the running minimum equals to the running maximum, the scales + and zero_points are set to 1.0 and 0. + """ + + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + ch_axis=0, + dtype=torch.quint8, + qscheme=torch.per_channel_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_channel(qscheme): + raise NotImplementedError( + "PerChannelMinMaxObserver's qscheme only support \ + torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." + ) + if is_dynamic: + raise NotImplementedError( + "PerChannelMinMaxObserver doesn't support dynamic quantization" + ) + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.ch_axis = ch_axis + self.register_buffer("min_val", torch.tensor([], **factory_kwargs)) + self.register_buffer("max_val", torch.tensor([], **factory_kwargs)) + if ( + self.qscheme == torch.per_channel_symmetric + and self.reduce_range + and self.dtype == torch.quint8 + ): + raise NotImplementedError( + "Cannot reduce range for symmetric quantization for quint8" + ) + + def forward(self, x_orig): + return self._forward(x_orig) + + def _forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + min_val = self.min_val + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + return self._calculate_qparams(self.min_val, self.max_val) + + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + def _load_from_state_dict( + self, + state_dict: dict[str, Any], + prefix: str, + local_metadata: dict[str, torch.Tensor], + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], + ): + version = local_metadata.get("version", None) + if version is not None and version < 3: + local_state = ["min_vals", "max_vals"] + expected_min_name = "min_vals" + expected_max_name = "max_vals" + else: + local_state = ["min_val", "max_val"] + expected_min_name = "min_val" + expected_max_name = "max_val" + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + # Custom handling to allow loading min_val or max_val + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. + if name == expected_min_name: + self.min_val.resize_(val.shape) + elif name == expected_max_name: + self.max_val.resize_(val.shape) + else: + warnings.warn( + f"Observer load_from_state_dict got unexpected name {name}" + ) + # For torchscript module we need to update the attributes here since we do not + # call the `_load_from_state_dict` function defined module.py + if torch.jit.is_scripting(): + if name == expected_min_name: + self.min_val.copy_(val) + elif name == expected_max_name: + self.max_val.copy_(val) + else: + warnings.warn( + f"Observer load_from_state_dict got unexpected name {name}" + ) + elif strict: + missing_keys.append(key) + + if not torch.jit.is_scripting(): + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _load_from_state_dict_script( + self, + state_dict: dict[str, Any], + prefix: str, + local_metadata: dict[str, torch.Tensor], + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], + ): + self._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def reset_min_max_vals(self): + """Resets the min/max values.""" + # This used to be torch.ones but that does not work because + # JIT compiler can optimize it via common subexpression elimination + # in which case both min_val and max_val point to the same tensor. + self.min_val = torch.rand( + 0, + ) + self.max_val = torch.rand( + 0, + ) + + +class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver): + r"""Observer module for computing the quantization parameters based on the + running per channel min and max values. + + This observer uses the tensor min/max statistics to compute the per channel + quantization parameters. The module records the running minimum and maximum + of incoming tensors, and uses this statistic to compute the quantization + parameters. + + Args: + averaging_constant: Averaging constant for min/max. + ch_axis: Channel axis + dtype: Quantized data type + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The quantization parameters are computed the same way as in + :class:`~torch.ao.quantization.observer.MovingAverageMinMaxObserver`, with the + difference that the running min/max values are stored per channel. + Scales and zero points are thus computed per channel as well. + + .. note:: If the running minimum equals to the running maximum, the scales + and zero_points are set to 1.0 and 0. + """ + + def __init__( + self, + averaging_constant=0.01, + ch_axis=0, + dtype=torch.quint8, + qscheme=torch.per_channel_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_channel(qscheme): + raise NotImplementedError( + "MovingAveragePerChannelMinMaxObserver's qscheme only support \ + torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." + ) + if is_dynamic: + raise NotImplementedError( + "MovingAveragePerChannelMinMaxObserver doesn't support dynamic quantization" + ) + super().__init__( + ch_axis=ch_axis, + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + self.averaging_constant = averaging_constant + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val = self.min_val + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = min_val + self.averaging_constant * (min_val_cur - min_val) + max_val = max_val + self.averaging_constant * (max_val_cur - max_val) + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + +class HistogramObserver(UniformQuantizationObserverBase): + r""" + The module records the running histogram of tensor values along with + min/max values. ``calculate_qparams`` will calculate scale and zero_point. + + Args: + bins: Number of bins to use for the histogram + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The scale and zero point are computed as follows: + + 1. Create the histogram of the incoming inputs. + The histogram is computed continuously, and the ranges per bin change + with every new tensor observed. + 2. Search the distribution in the histogram for optimal min/max values. + The search for the min/max values ensures the minimization of the + quantization error with respect to the floating point model. + 3. Compute the scale and zero point the same way as in the + :class:`~torch.ao.quantization.MinMaxObserver` + """ + + histogram: torch.Tensor + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + bins: int = 2048, + dtype: torch.dtype = torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + "HistogramObserver's qscheme only support torch.per_tensor_symmetric \ + and torch.per_tensor_affine." + ) + if is_dynamic: + raise NotImplementedError( + "HistogramObserver doesn't support dynamic quantization" + ) + # bins: The number of bins used for histogram calculation. + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.bins = bins + self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs)) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits + self.upsample_rate = ( + 16 # used to reduce quantization errors when upscaling histogram + ) + + def _get_norm( + self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor + ) -> torch.Tensor: + r""" + Compute the norm of the values uniformaly distributed between + delta_begin and delta_end. + Currently only L2 norm is supported. + + norm = density * (integral_{begin, end} x^2) + = density * (end^3 - begin^3) / 3 + """ + norm = ( + delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin + ) / 3 + return density * norm + + def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int): + r""" + Compute the quantization error if we use start_bin to end_bin as the + min and max to do the quantization. + """ + bin_width = (self.max_val.item() - self.min_val.item()) / self.bins + + dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins + if dst_bin_width == 0.0: + return 0.0 + + src_bin = torch.arange(self.bins, device=self.histogram.device) + # distances from the beginning of first dst_bin to the beginning and + # end of src_bin + src_bin_begin = (src_bin - next_start_bin) * bin_width + src_bin_end = src_bin_begin + bin_width + + # which dst_bins the beginning and end of src_bin belong to? + dst_bin_of_begin = torch.clamp( + torch.div(src_bin_begin, dst_bin_width, rounding_mode="floor"), + 0, + self.dst_nbins - 1, + ) + dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width + + dst_bin_of_end = torch.clamp( + torch.div(src_bin_end, dst_bin_width, rounding_mode="floor"), + 0, + self.dst_nbins - 1, + ) + density = self.histogram / bin_width + + norm = torch.zeros(self.bins, device=self.histogram.device) + + delta_begin = src_bin_begin - dst_bin_of_begin_center + delta_end = dst_bin_width / 2 + norm += self._get_norm( + delta_begin, + torch.ones(self.bins, device=self.histogram.device) * delta_end, + density, + ) + + norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm( + torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density + ) + + dst_bin_of_end_center = dst_bin_of_end * dst_bin_width + dst_bin_width / 2 + + delta_begin = -dst_bin_width / 2 + delta_end = src_bin_end - dst_bin_of_end_center + norm += self._get_norm(torch.tensor(delta_begin), delta_end, density) + + return norm.sum().item() + + def _non_linear_param_search(self) -> tuple[torch.Tensor, torch.Tensor]: + r"""Non-linear parameter search. + + An approximation for L2 error minimization for selecting min/max. + By selecting new min/max, we filter out outliers in input distribution. + This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in + caffe2/quantization/server/norm_minimization.cc + """ + assert self.histogram.size()[0] == self.bins, "bins mismatch" + bin_width = (self.max_val - self.min_val) / self.bins + + # cumulative sum + total = torch.sum(self.histogram).item() + cSum = torch.cumsum(self.histogram, dim=0) + + stepsize = 1e-5 # granularity + alpha = 0.0 # lower bound + beta = 1.0 # upper bound + start_bin = 0 + end_bin = self.bins - 1 + norm_min = float("inf") + + while alpha < beta: + # Find the next step + next_alpha = alpha + stepsize + next_beta = beta - stepsize + + # find the left and right bins between the quantile bounds + l = start_bin + r = end_bin + while l < end_bin and cSum[l] < next_alpha * total: + l = l + 1 + while r > start_bin and cSum[r] > next_beta * total: + r = r - 1 + + # decide the next move + next_start_bin = start_bin + next_end_bin = end_bin + if (l - start_bin) > (end_bin - r): + # move the start bin + next_start_bin = l + alpha = next_alpha + else: + # move the end bin + next_end_bin = r + beta = next_beta + + if next_start_bin == start_bin and next_end_bin == end_bin: + continue + + # calculate the quantization error using next_start_bin and next_end_bin + norm = self._compute_quantization_error(next_start_bin, next_end_bin) + + if norm > norm_min: + break + norm_min = norm + start_bin = next_start_bin + end_bin = next_end_bin + + new_min = self.min_val + bin_width * start_bin + new_max = self.min_val + bin_width * (end_bin + 1) + return new_min, new_max + + def _upscale_histogram( + self, + histogram: torch.Tensor, + orig_min: torch.Tensor, + orig_max: torch.Tensor, + update_min: torch.Tensor, + update_max: torch.Tensor, + ): + # this turns the histogram into a more fine-coarsed histogram to reduce + # bin quantization errors + histogram = histogram.repeat_interleave(self.upsample_rate) / self.upsample_rate + bin_size = (orig_max - orig_min) / (self.bins * self.upsample_rate) + mid_points_histogram = ( + torch.linspace( + orig_min, + orig_max, + self.bins * self.upsample_rate + 1, + device=orig_min.device, + )[:-1].to(histogram.device) + + 0.5 * bin_size + ) + boundaries_new_histogram = torch.linspace( + update_min, update_max, self.bins + 1, device=update_min.device + ).to(histogram.device) + # this maps the mid-poits of the histogram to the new histogram's space + bucket_assignments = ( + torch.bucketize(mid_points_histogram, boundaries_new_histogram, right=True) + - 1 + ) + # this then maps the histogram mid-points in the new space, weighted by the original histogram's values + # this is just the old histogram in the new histogram's space + + # In case due to numerical issues the values land higher/lower than the maximum/minimum + bucket_assignments[bucket_assignments >= self.bins] = self.bins - 1 + bucket_assignments[bucket_assignments < 0] = 0 + + update_histogram = torch.bincount( + bucket_assignments, weights=histogram, minlength=self.bins + ) + return update_histogram + + def _combine_histograms( + self, + orig_hist: torch.Tensor, + orig_min: torch.Tensor, + orig_max: torch.Tensor, + update_hist: torch.Tensor, + update_min: torch.Tensor, + update_max: torch.Tensor, + ) -> torch.Tensor: + # If the new min and max are the same as the current min and max, + # we can just add the new histogram to the original histogram + if update_min == orig_min and update_max == orig_max: + return orig_hist + update_hist + + # If the orig hist only has one value (i.e., the min and max are the same) + # we can just add it into new histogram + if orig_min == orig_max: + bin_value = torch.sum(update_hist) + transformed_orig_hist = ( + torch.histc(orig_min, bins=self.bins, min=update_min, max=update_max) # type: ignore[arg-type] + * bin_value + ) + return transformed_orig_hist + update_hist + + # We assume the update_hist is already in the target range, we will map the orig_max to it + assert update_min <= orig_min + assert update_max >= orig_max + + # Now we need to turn the old_histogram, into the range of the new histogram + transformed_orig_hist = self._upscale_histogram( + orig_hist, + orig_min, + orig_max, + update_min, + update_max, + ) + + return update_hist + transformed_orig_hist + + def reset_histogram( + self, x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor + ) -> None: + self.min_val.resize_(min_val.shape) + self.min_val.copy_(min_val) + self.max_val.resize_(max_val.shape) + self.max_val.copy_(max_val) + assert min_val.numel() == 1 and max_val.numel() == 1, ( + "histogram min/max values must be scalar." + ) + new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val) # type: ignore[arg-type] + self.histogram.detach_().resize_(new_histogram.shape) + self.histogram.copy_(new_histogram) + + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: # pyre-ignore[14] + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() + x_min, x_max = torch.aminmax(x) + # want to ignore torch.inf since we don't actually + # want to make our quantization range infinite + # and in practice those values will be clamped + if x_min == -torch.inf or x_max == torch.inf: + warnings.warn("torch.inf detected in input tensor, ignoring input") + x = x[x.abs() != torch.inf] + if x.numel() == 0: + return x_orig + x_min, x_max = torch.aminmax(x) + + current_min = self.min_val + current_max = self.max_val + + is_uninitialized = self.min_val == float("inf") or self.max_val == float("-inf") + if is_uninitialized: + self.reset_histogram(x, x_min, x_max) + else: + update_min, update_max = x_min, x_max + new_min = torch.min(current_min, update_min) + new_max = torch.max(current_max, update_max) + + # TODO: For some reason, this is required for it to pass torchscript test + # new_min and new_max should already have requires_grad set to False + new_min, new_max = new_min.detach(), new_max.detach() + update_histogram = torch.histc( + x, + self.bins, + min=new_min, # type: ignore[arg-type] + max=new_max, # type: ignore[arg-type] + ).to(self.histogram.device) + if new_min == current_min and new_max == current_max: + combined_histogram = self.histogram + update_histogram + self.histogram.detach_().resize_(combined_histogram.shape) + self.histogram.copy_(combined_histogram) + else: + combined_histogram = self._combine_histograms( + self.histogram, + current_min, + current_max, + update_histogram, + new_min, + new_max, + ) + self.histogram.detach_().resize_(combined_histogram.shape) + self.histogram.copy_(combined_histogram) + self.min_val.detach_().resize_(new_min.shape) + self.min_val.copy_(new_min) + self.max_val.detach_().resize_(new_max.shape) + self.max_val.copy_(new_max) + + return x_orig + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + is_uninitialized = self.min_val == float("inf") and self.max_val == float( + "-inf" + ) + if is_uninitialized: + warnings.warn( + "must run observer before calling calculate_qparams.\ + Returning default scale and zero point " + ) + return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor( + [0], device=self.min_val.device.type + ) + assert self.bins == len(self.histogram), ( + "The number of bins in histogram should be equal to the number of bins " + "supplied while making this observer" + ) + + new_min, new_max = self._non_linear_param_search() + + return self._calculate_qparams(new_min, new_max) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "min_val"] = self.min_val + destination[prefix + "max_val"] = self.max_val + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version < 3: + # if min_val and max_val are not initialized, update their shape + # to account for the differences between v2 and v3 + min_val_name, max_val_name = prefix + "min_val", prefix + "max_val" + if min_val_name in state_dict: + if state_dict[min_val_name].shape == torch.Size([0]): + state_dict[min_val_name] = torch.tensor(float("inf")) + if max_val_name in state_dict: + if state_dict[max_val_name].shape == torch.Size([0]): + state_dict[max_val_name] = torch.tensor(float("-inf")) + + local_state = ["min_val", "max_val"] + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + setattr(self, name, val) + elif strict: + missing_keys.append(key) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + +class FixedQParamsObserver(ObserverBase): + r""" + Observer that simulates quantize and dequantize with fixed + quantization parameters in training time. Only per tensor + quantization is supported. + + Args: + `scale` (float): fixed scale for the observer + `zero_point` (int): fixed zero point for the observer + `dtype`, `qscheme`, `quant_min`, `quant_max` + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__( + self, + scale, + zero_point, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + quant_min=0, + quant_max=255, + is_dynamic=False, + **kwargs, + ): + if is_dynamic: + raise NotImplementedError( + "FixedQParamsObserver doesn't support dynamic quantization" + ) + super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs) + self.quant_min = quant_min + self.quant_max = quant_max + self.register_buffer("scale", torch.tensor([scale], dtype=torch.float)) + self.register_buffer("zero_point", torch.tensor([zero_point], dtype=torch.int)) + self.dtype = dtype + self.qscheme = qscheme + + def forward(self, X): + return X + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + return self.scale, self.zero_point + + +class PlaceholderObserver(ObserverBase): + r""" + Observer that doesn't do anything and just passes its configuration to the + quantized module's ``.from_float()``. + + Can be used for quantization to float16 which doesn't require determining + ranges. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + quant_min: minimum value in quantized domain (TODO: align behavior with other observers) + quant_max: maximum value in quantized domain + custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation + (Can be used in Graph Mode Passes for special case ops). + compute_dtype (deprecated): if set, marks the future quantize function to use + dynamic quantization instead of static quantization. + This field is deprecated, use `is_dynamic=True` instead. + is_dynamic: if True, the `quantize` function in the reference model + representation taking stats from this observer instance will + use dynamic quantization. + """ + + def __init__( + self, + dtype=torch.float32, + custom_op_name="", + compute_dtype=None, + quant_min=None, + quant_max=None, + qscheme=None, + eps=None, + is_dynamic=False, + ) -> None: + super().__init__(dtype=dtype, is_dynamic=is_dynamic) + if qscheme is None: + qscheme = torch.per_tensor_affine + if eps is None: + eps = torch.finfo(torch.float32).eps + + # dtype of input of the target operator, e.g. for dynamic quantization + # ops, the dtype will be float32 + self.dtype = dtype + self.qscheme = qscheme + self.quant_min = quant_min + self.quant_max = quant_max + self.eps = eps + self.custom_op = custom_op_name + # used for configuration of computation type for dynamic quantization + if compute_dtype: + is_dynamic = True + warnings.warn( + "Please use `is_dynamic` instead of `compute_dtype`. \ + `compute_dtype` will be deprecated in a future release \ + of PyTorch." + ) + + def forward(self, x): + return x + + @torch.jit.export + def extra_repr(self): + return f"dtype={self.dtype}, is_dynamic={self.is_dynamic}" + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for PlaceholderObserver" + ) + + +class RecordingObserver(ObserverBase): + r""" + The module is mainly for debug and records the tensor values during runtime. + + Args: + dtype: Quantized data type + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + """ + + __annotations__ = {"tensor_val": list[Optional[torch.Tensor]]} + + def __init__(self, dtype=torch.quint8): + super().__init__(dtype=dtype, is_dynamic=False) + self.tensor_val = [] + + def forward(self, x): + self.tensor_val.append(x.clone()) + return x + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for RecordingObserver" + ) + + @torch.jit.export + def get_tensor_value(self): + return self.tensor_val + + +class NoopObserver(ObserverBase): + r""" + Observer that doesn't do anything and just passes its configuration to the + quantized module's ``.from_float()``. + + Primarily used for quantization to float16 which doesn't require determining + ranges. + + Args: + dtype: Quantized data type + custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation + (Can be used in Graph Mode Passes for special case ops). + """ + + def __init__(self, dtype=torch.float16, custom_op_name="") -> None: + super().__init__(dtype=dtype, is_dynamic=False) + self.dtype = dtype + self.custom_op = custom_op_name + + def forward(self, x): + return x + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for NoopObserver" + ) + + +class ReuseInputObserver(ObserverBase): + r"""This observer is used when we want to reuse the observer from the operator + that produces the input Tensor, typically used for operators like reshape, e.g. + ``` + x0 = ... + x1 = x0.reshape() + ``` + if we configure x0 to be observed by some observer, let's say MinMaxObserver, + and reshape is configured with ReuseInputObserver, we'll reuse the observer instance + for x0 for x1 (output of reshape). If x0 is not observed, we also won't observe x1. + + Note: this is only enabled in FX Graph Mode Quantization + """ + + def __init__(self) -> None: + super().__init__(torch.quint8, is_dynamic=False) + + def forward(self, x): + return x + + @torch.jit.export + def calculate_qparams(self): # type: ignore[override] + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for ReuseInputObserver" + ) + + +""" +# Experimental Affine Quantization Feature START +We plan to merge the following with torchao repo after we move pt2e flow to torchao +copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py +""" +from dataclasses import dataclass +from enum import auto, Enum + + +class MappingType(Enum): + """How floating point number is mapped to integer number + + symmetric mapping means floating point range is symmetrically mapped to integer range + let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4) + we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7) + e.g. scale = (10.2 - (-10.2)) / (7 - (-8)) + + SYMMETRIC_NO_CLIPPING_ERR is a variant of symmetric mapping, where the scale is the max of smin + and smax, where smin = min_val_neg / quant_min, and smax = max_val_pos / quant_max. By calculating + smin and smax individually, there can be less round error on negative values, and no out-of-range + of all floating point values. + + asymmetric mapping means we just directly map the floating point range to integer range, + for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate quantization parameter + based on this mapping + e.g. scale = (10.2 - (-3.5)) / (7 - (-8)) + """ + + SYMMETRIC = auto() + SYMMETRIC_NO_CLIPPING_ERR = auto() + ASYMMETRIC = auto() + + +class ZeroPointDomain(Enum): + """Enum that indicate whether zero_point is in integer domain or floating point domain + + integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer) + float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale + none domain: quantized_val = (float_val / scale) + """ + + INT = auto() + FLOAT = auto() + NONE = auto() + + +class TorchAODType(Enum): + """ + Placeholder for dtypes that do not exist in PyTorch core yet. + """ + + # torch.int1 to torch.int7 will be added to PyTorch 2.6 + # These will remain here for BC with older PyTorch versions + INT1 = auto() + INT2 = auto() + INT3 = auto() + INT4 = auto() + INT5 = auto() + INT6 = auto() + INT7 = auto() + + +@dataclass(frozen=True) +class Granularity: + """ + Base class for representing the granularity of quantization. + + This class serves as a parent for specific granularity types used in + quantization operations, such as per-tensor or per-axis quantization. + """ + + +@dataclass(frozen=True) +class PerBlock(Granularity): + """ + Represents per-block granularity in quantization. See + :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for + `block_size` + + Attributes: + block_size (Tuple[int, ...]): The size of each quantization group + """ + + block_size: tuple[int, ...] + + +@dataclass(frozen=True) +class PerTensor(Granularity): + """ + Represents per-tensor granularity in quantization. + + This granularity type calculates the quantization parameters + based off the entire tensor. + + """ + + +@dataclass(frozen=True) +class PerAxis(Granularity): + """ + Represents per-axis granularity in quantization. + + This granularity type calculates different quantization parameters + along a specified axis of the tensor. + + For example if the input tensor is shape [8, 16] and axis=0, then + the quantization parameters are calculated for each row of the tensor. + Giving a total of 8 quantization parameters. + + Attributes: + axis (int): The axis along which reduction is performed. + """ + + axis: int + + +@dataclass(frozen=True) +class PerGroup(Granularity): + """ + Represents per-channel group granularity in quantization. + + This granularity type calculates different quantization parameters + for each group of elements. + + For example if the input tensor is shape [8, 16], and the group size is 4, then + the input tensor is reshaped to [64, 4] + quantization parameters are calculated for each group of 4 elements, + giving a total of 64 quantization parameters. + + Attributes: + group_size (int): The size of each quantization group + + """ + + group_size: int + + +class PerRow(Granularity): + """ + Represents row-wise granularity in quantization. + + This is a special case of per-axis quantization and is unique to Float8 matmuls + where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight + is quantized with a block_size of (1, weight.shape[1]). + """ + + +class PerToken(Granularity): + """ + Represents per-token granularity in quantization. + + This granularity type calculates a different set of quantization parameters + for each token, which is represented as the last dimension of the tensor. + + For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens + with 4 elements each, and we will calculate 6 sets of quantization parameters, + one for each token. + + If the input tensor has only two dimensions, e.g. [8, 16], then this is + equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters. + """ + + +def get_block_size( + input_shape: tuple[int, ...], granularity: Granularity +) -> tuple[int, ...]: + """Get the block size based on the input shape and granularity type. + + Args: + input_shape: The input tensor shape possibly more than 2 dimensions + granularity: The granularity type of the quantization + """ + assert isinstance(granularity, Granularity), ( + "Please provide an instance of Granularity, not subclass of it" + ) + if isinstance(granularity, PerTensor): + return input_shape + elif isinstance(granularity, PerAxis): + block_size = list(input_shape) + block_size[granularity.axis] = 1 + return tuple(block_size) + elif isinstance(granularity, PerRow): + return (1,) * (len(input_shape) - 1) + (input_shape[-1],) + elif isinstance(granularity, PerGroup): + assert len(input_shape) == 2, ( + f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" + ) + return (1, granularity.group_size) + elif isinstance(granularity, PerToken): + block_size = [1] * len(input_shape) + block_size[-1] = input_shape[-1] + return tuple(block_size) + raise ValueError(f"Unsupported Granularity: {granularity}") + + +class AffineQuantizedObserverBase(ABC, torch.nn.Module): + """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) + + Args: + `granularity` and `block_size`: The granularity of the quantization, + must specify at least one, if both are specified `block_size` takes precedence + Current supported granularity type are `PerTensor` and `PerAxis` + other args: please see `:class:torchao.dtypes.AffineQuantizedTensor` + """ + + with_args = classmethod(_with_args) + + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + super().__init__() + assert granularity is not None, "granularity is None" + + self.mapping_type = mapping_type + self.target_dtype = target_dtype + self.granularity = granularity + self.quant_min = quant_min + self.quant_max = quant_max + self.eps = eps + self.scale_dtype = scale_dtype + self.zero_point_dtype = zero_point_dtype + self.preserve_zero = preserve_zero + self.zero_point_domain = zero_point_domain + # populatd during forward + self.block_size = None + self.original_dtype = None + + @abstractmethod + def forward(self, input: torch.Tensor) -> torch.Tensor: + """forward function should take the input tensor + and updates internal stats and return the original input Tensor + """ + + @abstractmethod + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: + """Calculate quantization parameter based on the stats attached to the observer module + and returns a tuple of scale and zero_point Tensor + """ + + def convert(self, model: torch.fx.GraphModule, observer_node: Node): + """ + Converts the observer node in the graph into its quantized representation + + Args: + model: graph module to conver the observer node in + observer_node: the observer node to convert + """ + from torch.ao.quantization.fx.utils import create_getattr_from_value + + with model.graph.inserting_before(observer_node): + assert self.block_size is not None, "Expecting block_size to be populated" + assert self.original_dtype is not None, ( + "Expecting original_dtype to be populated" + ) + if hasattr(self, "is_dynamic") and self.is_dynamic: + choose_qparams_affine = model.graph.call_function( + torch.ops.pt2e_quant.choose_qparams_affine, + ( + observer_node.args[0], + self.mapping_type.name, + self.block_size, + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain.name, + ), + ) + scale_node = model.graph.call_function( + operator.getitem, (choose_qparams_affine, 0) + ) + zero_point_node = model.graph.call_function( + operator.getitem, (choose_qparams_affine, 1) + ) + else: + scale, zero_point = self.calculate_qparams() + scale_node = create_getattr_from_value( + model, model.graph, "_scale", scale + ) + zero_point_node = create_getattr_from_value( + model, model.graph, "_zero_point", zero_point + ) + + q_node = model.graph.call_function( + torch.ops.pt2e_quant.quantize_affine, + ( + observer_node.args[0], + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {}, + ) + dq_node = model.graph.call_function( + torch.ops.pt2e_quant.dequantize_affine, + ( + q_node, + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {"output_dtype": self.original_dtype}, + ) + observer_node.replace_all_uses_with(dq_node) + model.graph.erase_node(observer_node) + + +def _is_observer_script_module(mod, obs_type_name): + """Returns true if given mod is an instance of Observer script module.""" + if isinstance(mod, torch.jit.RecursiveScriptModule): + # qualified name looks like '__torch__.torch.ao.quantization.observer.___torch_mangle_2.MinMaxObserver' + suffix = mod._c.qualified_name.split(".", 1)[1] + name = re.sub(r"\.___torch_mangle_\d+", "", suffix) + return obs_type_name in name + return False + + +# Experimental Affine Quantization Feature END + + +def _is_activation_post_process(module): + return isinstance( + module, + ( + torch.ao.quantization.ObserverBase, + torch.ao.quantization.FakeQuantizeBase, + AffineQuantizedObserverBase, + ), + ) or _is_observer_script_module(module, "quantization.observer") + + +def _is_per_channel_script_obs_instance(module): + if isinstance(module, torch.jit.RecursiveScriptModule): + return _is_observer_script_module( + module, "quantization.observer.PerChannelMinMaxObserver" + ) or _is_observer_script_module( + module, "quantization.observer.MovingAveragePerChannelMinMaxObserver" + ) + return False + + +def get_observer_state_dict(mod): + r""" + Returns the state dict corresponding to the observer stats. + Traverse the model state_dict and extract out the stats. + """ + od = OrderedDict() + if isinstance(mod, torch.jit.RecursiveScriptModule): + for k, v in mod.state_dict().items(): + if "observer" in k: + od[k] = v + else: + # path for GraphModule and nn.Module (eager mode) + for k, v in mod.state_dict().items(): + if "activation_post_process" in k: + od[k] = v + od._metadata = mod.state_dict()._metadata # type: ignore[attr-defined] + return od + + +def load_observer_state_dict(mod, obs_dict): + r""" + Given input model and a state_dict containing model observer stats, + load the stats back into the model. The observer state_dict can be saved + using torch.ao.quantization.get_observer_state_dict + """ + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + for name, module in mod.named_modules(): + prefix = name + "." + if _is_activation_post_process(module): + if _is_per_channel_script_obs_instance(module): + # For per-channel observers we need to call a custom load_from_state_dict to resize the tensor. + # However this is not called when the module is scripted and we end up calling the default one in module.py + module._load_from_state_dict_script( + obs_dict, prefix, {}, True, missing_keys, unexpected_keys, [] + ) + else: + module._load_from_state_dict( + obs_dict, prefix, {}, False, missing_keys, unexpected_keys, [] + ) + for k in missing_keys: + if "observer" in k or "activation_post_process" in k: + raise Exception( # noqa: TRY002 + f"Missing keys for observer {k} in state_dict" + ) + for k in unexpected_keys: + if "observer" in k or "activation_post_process" in k: + raise Exception( # noqa: TRY002 + f"Unexpected keys for observer {k} in state_dict" + ) + + +# Restrict activations to be in the range (0,127) +default_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127) +""" +Default observer for static quantization, usually used for debugging. +""" + +default_placeholder_observer = PlaceholderObserver +""" +Default placeholder observer, usually used for quantization to torch.float16. +""" + +default_debug_observer = RecordingObserver +""" +Default debug-only observer. +""" + +default_weight_observer = MinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_tensor_symmetric +) +""" +Default weight observer. +""" + +weight_observer_range_neg_127_to_127 = MinMaxObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + quant_min=-127, + quant_max=127, + eps=2**-12, +) +""" +Symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +default_histogram_observer = HistogramObserver.with_args(quant_min=0, quant_max=127) +""" +Default histogram observer, usually used for PTQ. +""" + +default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric +) +""" +Default per-channel weight observer, usually used on backends where per-channel +weight quantization is supported, such as `fbgemm`. +""" + +per_channel_weight_observer_range_neg_127_to_127 = PerChannelMinMaxObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + quant_min=-127, + quant_max=127, + eps=2**-12, +) +""" +Per-channel, symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +default_dynamic_quant_observer = PlaceholderObserver.with_args( + dtype=torch.quint8, + quant_min=0, + quant_max=255, + is_dynamic=True, +) +""" +Default observer for dynamic quantization. +""" + +default_float_qparams_observer = PerChannelMinMaxObserver.with_args( + dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 +) +""" +Default observer for a floating point zero-point. +""" + +default_float_qparams_observer_4bit = PerChannelMinMaxObserver.with_args( + dtype=torch.quint4x2, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 +) +""" +Default observer for a floating point zero-point and 4 bit activations. +""" + +# TODO(future PR): remove these defaults and enforce activation functions +# to explicitly specify their output range +default_fixed_qparams_range_neg1to1_observer = FixedQParamsObserver.with_args( + scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255 +) +default_fixed_qparams_range_0to1_observer = FixedQParamsObserver.with_args( + scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255 +) +# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases +default_symmetric_fixed_qparams_observer = default_fixed_qparams_range_neg1to1_observer +default_affine_fixed_qparams_observer = default_fixed_qparams_range_0to1_observer + +""" +Default observers for fixed qparams operations. +""" + +default_reuse_input_observer = ReuseInputObserver +""" +Default observer for operators like reshape that reuses the observer of input to +the operator +""" diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/qconfig.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/qconfig.py new file mode 100644 index 0000000000000000000000000000000000000000..efee5302ad42ad8450c8746ebdf9f232b75fa47d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/qconfig.py @@ -0,0 +1,701 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from collections import namedtuple +from typing import Any, Optional, Union +from typing_extensions import deprecated + +import torch +import torch.nn as nn +from torch.ao.quantization.fake_quantize import ( + default_dynamic_fake_quant, + default_embedding_fake_quant, + default_embedding_fake_quant_4bit, + default_fake_quant, + default_fused_act_fake_quant, + default_fused_per_channel_wt_fake_quant, + default_fused_wt_fake_quant, + default_per_channel_weight_fake_quant, + default_weight_fake_quant, + FakeQuantize, + FakeQuantizeBase, + fused_per_channel_wt_fake_quant_range_neg_127_to_127, + fused_wt_fake_quant_range_neg_127_to_127, + FusedMovingAvgObsFakeQuantize, +) + +from .observer import ( + _PartialWrapper, + default_debug_observer, + default_dynamic_quant_observer, + default_float_qparams_observer, + default_float_qparams_observer_4bit, + default_observer, + default_per_channel_weight_observer, + default_placeholder_observer, + default_reuse_input_observer, + default_weight_observer, + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + NoopObserver, + ObserverBase, + per_channel_weight_observer_range_neg_127_to_127, + PlaceholderObserver, + ReuseInputObserver, + weight_observer_range_neg_127_to_127, +) + + +__all__ = [ + "QConfig", + # TODO: deprecated, remove + "QConfigDynamic", + "default_qconfig", + "default_debug_qconfig", + "default_per_channel_qconfig", + "default_dynamic_qconfig", + "float16_dynamic_qconfig", + "float16_static_qconfig", + "per_channel_dynamic_qconfig", + "float_qparams_weight_only_qconfig", + "float_qparams_weight_only_qconfig_4bit", + "default_quint8_weight_qconfig", + "default_qat_qconfig", + "default_dynamic_qat_qconfig", + "default_weight_only_qconfig", + "default_activation_only_qconfig", + "default_qat_qconfig_v2", + "default_reuse_input_qconfig", + "default_symmetric_qnnpack_qconfig", + "default_per_channel_symmetric_qnnpack_qconfig", + "default_symmetric_qnnpack_qat_qconfig", + "default_per_channel_symmetric_qnnpack_qat_qconfig", + "default_embedding_qat_qconfig", + "default_embedding_qat_qconfig_4bit", + "get_default_qconfig", + "get_default_qat_qconfig", + "get_default_qconfig_dict", + "get_default_qat_qconfig_dict", + "QConfigAny", + "qconfig_equals", +] + + +class QConfig(namedtuple("QConfig", ["activation", "weight"])): + """ + Describes how to quantize a layer or a part of the network by providing + settings (observer classes) for activations and weights respectively. + + + Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable that returns + instances on invocation, not the concrete observer instances themselves. + Quantization preparation function will instantiate observers multiple times for each of the layers. + + + Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args` + method (that behaves like functools.partial):: + + my_qconfig = QConfig( + activation=MinMaxObserver.with_args(dtype=torch.qint8), + weight=default_observer.with_args(dtype=torch.qint8), + ) + + """ + + __slots__ = () + + def __new__(cls, activation, weight): + # catch common mistakes + if isinstance(activation, nn.Module) or isinstance(weight, nn.Module): + raise ValueError( + "QConfig received observer instance, please pass observer class instead. " + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" + ) + return super().__new__(cls, activation, weight) + + +@deprecated( + "`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead", + category=FutureWarning, +) +class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])): + """ + Describes how to dynamically quantize a layer or a part of the network by providing + settings (observer classes) for weights. + + It's like QConfig, but for dynamic quantization. + + Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable that returns + instances on invocation, not the concrete observer instances themselves. + Quantization function will instantiate observers multiple times for each of the layers. + + Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args` + method (that behaves like functools.partial):: + + my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8)) + """ + + __slots__ = () + + def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity): + # catch common mistakes + if isinstance(weight, nn.Module): + raise ValueError( + "QConfigDynamic received observer instance, please pass observer class instead. " + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed" + ) + return super().__new__(cls, activation, weight) + + +default_qconfig = QConfig(activation=default_observer, weight=default_weight_observer) +""" +Default qconfig configuration. +""" + +default_debug_qconfig = QConfig( + weight=default_weight_observer, activation=default_debug_observer +) +""" +Default qconfig configuration for debugging. +""" + +default_per_channel_qconfig = QConfig( + activation=default_observer, weight=default_per_channel_weight_observer +) +""" +Default qconfig configuration for per channel weight quantization. +""" + +default_dynamic_qconfig = QConfig( + activation=default_dynamic_quant_observer, weight=default_weight_observer +) +""" +Default dynamic qconfig. +""" + +float16_dynamic_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.float16, is_dynamic=True), + weight=PlaceholderObserver.with_args(dtype=torch.float16), +) +""" +Dynamic qconfig with weights quantized to `torch.float16`. +""" + +float16_static_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.float16), + weight=PlaceholderObserver.with_args(dtype=torch.float16), +) +""" +Dynamic qconfig with both activations and weights quantized to `torch.float16`. +""" + +per_channel_dynamic_qconfig = QConfig( + activation=default_dynamic_quant_observer, + weight=default_per_channel_weight_observer, +) +""" +Dynamic qconfig with weights quantized per channel. +""" + +float_qparams_weight_only_qconfig = QConfig( + activation=default_placeholder_observer, weight=default_float_qparams_observer +) +""" +Dynamic qconfig with weights quantized with a floating point zero_point. +""" + +float_qparams_weight_only_qconfig_4bit = QConfig( + activation=default_placeholder_observer, weight=default_float_qparams_observer_4bit +) + +default_qat_qconfig = QConfig( + activation=default_fake_quant, weight=default_weight_fake_quant +) +""" +Default qconfig for QAT. +""" + +default_dynamic_qat_qconfig = QConfig( + activation=default_dynamic_fake_quant, weight=default_weight_fake_quant +) +""" +Default qconfig for dynamic QAT. +""" + +default_weight_only_qconfig = QConfig( + activation=torch.nn.Identity, weight=default_weight_fake_quant +) +""" +Default qconfig for quantizing weights only. +""" + +default_activation_only_qconfig = QConfig( + activation=default_fake_quant, weight=torch.nn.Identity +) +""" +Default qconfig for quantizing activations only. +""" + +# QAT config that uses a fused observer + fake quant modules for optimized training performance. +# to modify the activation/weight observers, the default entries in fake_quantize.py can be modified. +default_qat_qconfig_v2 = QConfig( + activation=default_fused_act_fake_quant, weight=default_fused_wt_fake_quant +) +""" +Fused version of `default_qat_config`, has performance benefits. +""" + +default_reuse_input_qconfig = QConfig( + activation=default_reuse_input_observer, weight=NoopObserver +) +""" +Default qconfig for operators that reuse the observers from input Tensor, e.g. reshape +""" + + +def get_default_qconfig(backend="x86", version=0): + """ + Returns the default PTQ qconfig for the specified backend. + + Args: + * `backend` (str): a string representing the target backend. Currently supports + `x86` (default), `fbgemm`, `qnnpack` and `onednn`. + + Return: + qconfig + """ + supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] + if backend not in supported_backends: + raise AssertionError( + "backend: " + + str(backend) + + f" not supported. backend must be one of {supported_backends}" + ) + + if version == 0: + if backend == "fbgemm": + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=True), + weight=default_per_channel_weight_observer, + ) + elif backend == "qnnpack": + # TODO: make this compatible with xnnpack constraints + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=False), + weight=default_weight_observer, + ) + elif backend == "onednn": + if not torch.cpu._is_vnni_supported(): + warnings.warn( + "Default qconfig of oneDNN backend with reduce_range of false may have accuracy issues " + "on CPU without Vector Neural Network Instruction support." + ) + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=False), + weight=default_per_channel_weight_observer, + ) + elif backend == "x86": + qconfig = QConfig( + activation=HistogramObserver.with_args(reduce_range=True), + weight=default_per_channel_weight_observer, + ) + else: + # won't reach + qconfig = default_qconfig + else: + raise AssertionError( + "Version number: " + + str(version) + + " in get_default_qconfig is not supported. Version number must be 0" + ) + + return qconfig + + +""" +Default, symmetric PTQ qconfig for the specified backend. And a per_channel +variant of the same. + +Symmetric here applies to signed weights with zero point = 0, and additional +value restrictions. The activations are also signed 8-bit integers with this +qconfig. + + * Once this change is merged [as of 3/17/22], with backend or qengine = + 'qnnpack', some quantized operators with this symmetric qconfig may use + operators from xnnpack library. + + ** Support to use xnnpack ops with `qnnpack` backed for asymmetric + qconfig (returned by get_default_qconfig()) is not available yet. + + * This qconfig uses signed activations and weights. Weights have added + restrictions such as zero point is forced to be 0, making the weights + symmetric, hence the name. And the 8-bit quantized values are + restricting to to [-127, +127], excluding -128. + + * xnnpack has a requantization scale value restriction, 0x1p-32 <= + requantization_scale < 256.0 where, `requantization_scale = (input_scale + * kernel_scale) / (output_scale)`. Using this eps (w/ assumed max value + of 256) is to prevent requantization_scale to go below xnnpack lower + threshold. +""" +default_symmetric_qnnpack_qconfig = QConfig( + activation=HistogramObserver.with_args( + dtype=torch.qint8, reduce_range=False, eps=2**-12 + ), + weight=weight_observer_range_neg_127_to_127, +) + +default_per_channel_symmetric_qnnpack_qconfig = QConfig( + activation=HistogramObserver.with_args( + dtype=torch.qint8, reduce_range=False, eps=2**-12 + ), + weight=per_channel_weight_observer_range_neg_127_to_127, +) + +default_embedding_qat_qconfig = QConfig( + activation=NoopObserver.with_args(dtype=torch.float32), + weight=default_embedding_fake_quant, +) + +default_embedding_qat_qconfig_4bit = QConfig( + activation=NoopObserver.with_args(dtype=torch.float32), + weight=default_embedding_fake_quant_4bit, +) + +default_quint8_weight_qconfig = QConfig( + activation=HistogramObserver, weight=MinMaxObserver +) + + +def get_default_qat_qconfig(backend="x86", version=1): + """ + Returns the default QAT qconfig for the specified backend. + + Args: + * `backend` (str): a string representing the target backend. Currently supports + `x86` (default), `fbgemm`, `qnnpack` and `onednn`. + * `version`: version, for backwards compatibility. Can be `None` or `1`. + + Return: + qconfig + """ + supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] + if backend not in supported_backends: + raise AssertionError( + "backend: " + + str(backend) + + f" not supported. backend must be one of {supported_backends}" + ) + + # Histogram observer is too slow for quantization aware training + if version == 0: + if backend == "fbgemm": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_per_channel_weight_fake_quant, + ) + elif backend == "qnnpack": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=False, + ), + weight=default_weight_fake_quant, + ) + elif backend == "onednn": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255 + ), + weight=default_per_channel_weight_fake_quant, + ) + elif backend == "x86": + qconfig = QConfig( + activation=FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_per_channel_weight_fake_quant, + ) + else: + qconfig = default_qat_qconfig + # Use the fused observe + fake_quant modules for doing QAT. + elif version == 1: + if backend == "fbgemm": + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_fused_per_channel_wt_fake_quant, + ) + elif backend == "qnnpack": + # TODO: make this compatible with xnnpack constraints + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=False, + ), + weight=default_fused_wt_fake_quant, + ) + elif backend == "onednn": + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255 + ), + weight=default_fused_per_channel_wt_fake_quant, + ) + elif backend == "x86": + qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True, + ), + weight=default_fused_per_channel_wt_fake_quant, + ) + else: + qconfig = default_qat_qconfig_v2 + else: + raise AssertionError( + "Version number: " + + str(version) + + "in get_default_qat_qconfig is not supported. Version number must be 0 or 1" + ) + + return qconfig + + +""" +Default symmetric QAT qconfig for qnnpack. And its per channel weight variant. +""" +default_symmetric_qnnpack_qat_qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + reduce_range=False, + eps=2**-12, + ), + weight=fused_wt_fake_quant_range_neg_127_to_127, +) + +default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + reduce_range=False, + eps=2**-12, + ), + weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127, +) + +_default_fp32_placeholder_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.float32), + weight=PlaceholderObserver.with_args(dtype=torch.float32), +) + +_default_quint8_placeholder_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.quint8), + # operators using this qconfig doesn't have weights + weight=None, +) + + +@deprecated( + "`torch.ao.quantization.get_default_qconfig_dict` is deprecated and will be removed in " + "a future version. Please use `torch.ao.quantization.get_default_qconfig_mapping` instead.", + category=FutureWarning, +) +def get_default_qconfig_dict(backend="x86", version=0): + return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict() + + +@deprecated( + "`torch.ao.quantization.get_default_qat_qconfig_dict` is deprecated and will be removed in " + "a future version. Please use `torch.ao.quantization.get_default_qat_qconfig_mapping` instead.", + category=FutureWarning, +) +def get_default_qat_qconfig_dict(backend="x86", version=1): + return torch.ao.quantization.get_default_qat_qconfig_mapping( + backend, version + ).to_dict() + + +def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> None: + """ + Verifies that this `qconfig` is valid. + """ + if qconfig is None: + return + is_conv_transpose_mod = isinstance( + mod, + (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d), + ) + if is_conv_transpose_mod: + if qconfig.weight is None: + # for now, we assume that any qconfig for ConvTranspose without a weight is valid + return + example_observer = qconfig.weight() + is_per_channel = isinstance( + example_observer, + ( + torch.ao.quantization.PerChannelMinMaxObserver, + torch.ao.quantization.MovingAveragePerChannelMinMaxObserver, + ), + ) + assert not is_per_channel, ( + "Per channel weight observer is not supported yet for ConvTranspose{n}d." + ) + + +QConfigAny = Optional[QConfig] +QConfigAny.__module__ = "torch.ao.quantization.qconfig" + + +def _add_module_to_qconfig_obs_ctr( + qconfig: QConfigAny, module: Optional[nn.Module] +) -> Any: + r"""This is a helper function for use in quantization prepare that updates a qconfig so that + the constructors stored in the qconfig will create observers on the same device that + 'module' is on. This is intended to be used when the qconfigs are propagated to each + module in order to avoid potential device alignment issues. + + Args: + qconfig: QConfig with obs constructors stored in activation and weight + module: module which the qconfig is related to + + Return: + qconfig: configured so that obs constructors set to construct on the same device as module + """ + + if module is None or qconfig is None or qconfig._fields != ("activation", "weight"): + return qconfig + + def get_factory_kwargs_based_on_module_device(): + assert isinstance(module, torch.nn.Module) + devices = {p.device for p in module.parameters()} | { + p.device for p in module.buffers() + } + device = next(iter(devices)) if len(devices) > 0 else None + return None if device is None else {"device": device} + + def configure_constructor_to_put_obs_on_module_device(original_constructor): + try: + # check if constructor can accept factory_kwargs + check = original_constructor.with_args(factory_kwargs=None) + check() + return original_constructor.with_callable_args( + factory_kwargs=get_factory_kwargs_based_on_module_device + ) + except AttributeError: # qconfig doesn't have activation or weight + return original_constructor + except TypeError: # the class doesn't accept factory_kwargs argument + return original_constructor + + activation = configure_constructor_to_put_obs_on_module_device(qconfig.activation) + weight = configure_constructor_to_put_obs_on_module_device(qconfig.weight) + + return QConfig(activation, weight) + + +_ObserverOrFakeQuantizeConstructor = Union[ + _PartialWrapper, type[ObserverBase], type[FakeQuantizeBase] +] + + +def _obs_or_fq_ctr_equals( + obs_or_fq1: _ObserverOrFakeQuantizeConstructor, + obs_or_fq2: _ObserverOrFakeQuantizeConstructor, +): + if isinstance(obs_or_fq1, _PartialWrapper) and isinstance( + obs_or_fq2, _PartialWrapper + ): + return _partial_wrapper_equals(obs_or_fq1, obs_or_fq2) + return obs_or_fq1 == obs_or_fq2 + + +def _partial_wrapper_equals(obs_or_fq1: _PartialWrapper, obs_or_fq2: _PartialWrapper): + """ + Return whether the two partial wrappers are equal, + """ + # functools.partial has no __eq__ operator defined so '==' defaults to 'is' + obs_or_fq1_keywords = copy.copy(obs_or_fq1.p.keywords) + obs_or_fq2_keywords = copy.copy(obs_or_fq2.p.keywords) + keywords_equal = True + # compare observer constructor with _obs_or_fq_ctr_equals since direct compare would fail + if "observer" in obs_or_fq1_keywords and "observer" in obs_or_fq2_keywords: + keywords_equal = keywords_equal and _obs_or_fq_ctr_equals( + obs_or_fq1_keywords["observer"], obs_or_fq2_keywords["observer"] + ) + obs_or_fq1_keywords.pop("observer") + obs_or_fq2_keywords.pop("observer") + keywords_equal = keywords_equal and obs_or_fq1_keywords == obs_or_fq2_keywords + return ( + obs_or_fq1.p.func == obs_or_fq2.p.func + and obs_or_fq1.p.args == obs_or_fq2.p.args + and keywords_equal + ) + + +def qconfig_equals(q1: QConfigAny, q2: QConfigAny): + """ + Returns `True` if `q1` equals `q2`, and `False` otherwise. + """ + if q1 is None or q2 is None: + return q1 == q2 + else: + assert q1 is not None and q2 is not None + try: + # Qconfig weight and activation can be either a partial wrapper, + # or an observer class. Special handling is required (above) for + # comparing partial wrappers. + activation_same = _obs_or_fq_ctr_equals(q1.activation, q2.activation) + weight_same = _obs_or_fq_ctr_equals(q1.weight, q2.weight) + return activation_same and weight_same + except AttributeError: + return q1 == q2 + + +def _activation_is_memoryless(qconfig: QConfig): + """ + Return whether the observer for activations defined in the given QConfig is memoryless. + This means a MovingAverage observer with averaging constant equal to 1. + """ + + def _is_memoryless(observer): + return ( + hasattr(observer, "averaging_constant") and observer.averaging_constant == 1 + ) + + act = qconfig.activation() + if isinstance(act, FakeQuantizeBase) and hasattr(act, "activation_post_process"): + return _is_memoryless(act.activation_post_process) + else: + return _is_memoryless(act) + + +def _is_reuse_input_qconfig(qconfig: Optional[QConfig]): + return ( + qconfig is not None + and isinstance(qconfig.activation(), ReuseInputObserver) + and isinstance(qconfig.weight(), NoopObserver) + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/qconfig_mapping.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/qconfig_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..bd34a6b8a1f4517888be968d67a30d125482e7e9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/qconfig_mapping.py @@ -0,0 +1,381 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from collections import OrderedDict +from typing import Any, Callable, Union + +import torch + +from .fake_quantize import default_weight_fake_quant, FixedQParamsFakeQuantize +from .observer import ( + _PartialWrapper, + default_fixed_qparams_range_0to1_observer, + default_fixed_qparams_range_neg1to1_observer, + default_placeholder_observer, + default_weight_observer, +) +from .qconfig import ( + default_quint8_weight_qconfig, + default_reuse_input_qconfig, + default_symmetric_qnnpack_qat_qconfig, + default_symmetric_qnnpack_qconfig, + get_default_qat_qconfig, + get_default_qconfig, + QConfig, + QConfigAny, +) + + +__all__ = [ + "get_default_qconfig_mapping", + "get_default_qat_qconfig_mapping", + "QConfigMapping", +] + + +# TODO: replace all usages with these constants +_GLOBAL_DICT_KEY = "" +_OBJECT_TYPE_DICT_KEY = "object_type" +_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex" +_MODULE_NAME_DICT_KEY = "module_name" +_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order" + +# TODO: derive this map from the BackendConfig +_FIXED_QPARAMS_OP_TO_OBSERVER: dict[Union[Callable, str], _PartialWrapper] = { + torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer, + torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer, + "hardsigmoid": default_fixed_qparams_range_0to1_observer, + "hardsigmoid_": default_fixed_qparams_range_0to1_observer, + torch.nn.Sigmoid: default_fixed_qparams_range_0to1_observer, + torch.sigmoid: default_fixed_qparams_range_0to1_observer, + "sigmoid": default_fixed_qparams_range_0to1_observer, + "sigmoid_": default_fixed_qparams_range_0to1_observer, + torch.nn.Softmax: default_fixed_qparams_range_0to1_observer, + torch.nn.Tanh: default_fixed_qparams_range_neg1to1_observer, + torch.tanh: default_fixed_qparams_range_neg1to1_observer, + "tanh": default_fixed_qparams_range_neg1to1_observer, + "tanh_": default_fixed_qparams_range_neg1to1_observer, +} + + +def _get_default_qconfig_mapping( + is_qat: bool, backend: str, version: int +) -> QConfigMapping: + """ + Return the default QConfigMapping for the given quantization type and backend. + """ + if is_qat: + qconfig = get_default_qat_qconfig(backend, version) + else: + qconfig = get_default_qconfig(backend, version) + default_weight = default_weight_fake_quant if is_qat else default_weight_observer + + # default_per_channel_weight_observer is not currently compatible with fbgemm backend + # so we have to modify the weight observer to default_weight_observer or another + # per tensor supported observer. + # see https://github.com/pytorch/pytorch/issues/47535 + if backend in ("fbgemm", "x86"): + qconfig_transpose = QConfig( + activation=qconfig.activation, weight=default_weight + ) + else: + qconfig_transpose = qconfig + + # currently layernorm only supports float weights + # we have to add this because otherwise there will be a extra quantize-dequantize pair + qconfig_layernorm = QConfig( + activation=qconfig.activation, weight=default_placeholder_observer + ) + + qconfig_mapping = ( + QConfigMapping() + .set_global(qconfig) + .set_object_type("reshape", default_reuse_input_qconfig) + .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) + .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) + .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) + .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) + .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) + .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) + .set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm) + .set_object_type(torch.nn.LayerNorm, qconfig_layernorm) + .set_object_type(torch.nn.PReLU, default_quint8_weight_qconfig) + ) + # Use special observers for ops with fixed qparams + fixed_qparams_observer_to_qconfig: dict[Any, QConfigAny] = {} + for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items(): + if observer in fixed_qparams_observer_to_qconfig: + fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[observer] + else: + if is_qat: + activation = FixedQParamsFakeQuantize.with_args(observer=observer) + else: + activation = observer + fixed_qparams_qconfig = QConfig( + activation=activation, weight=default_weight + ) + fixed_qparams_observer_to_qconfig[observer] = fixed_qparams_qconfig + qconfig_mapping.set_object_type(fixed_qparams_op, fixed_qparams_qconfig) + + # TODO Currently it's required that separate ops in a fused op/module have the same qconfig. + # Need to be able to support fusion of ops with different qconfigs + + return qconfig_mapping + + +def get_default_qconfig_mapping(backend="x86", version=0) -> QConfigMapping: + """ + Return the default QConfigMapping for post training quantization. + + Args: + * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be + one of ["x86" (default), "fbgemm", "qnnpack", "onednn"] + * ``version`` (int) : the version for the default qconfig mapping + """ + # TODO: add assert for backend choices + return _get_default_qconfig_mapping(False, backend, version) + + +def get_default_qat_qconfig_mapping(backend="x86", version=1) -> QConfigMapping: + """ + Return the default QConfigMapping for quantization aware training. + + Args: + * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be + one of ["x86" (default), "fbgemm", "qnnpack", "onednn"] + * ``version`` (int) : the version for the default qconfig mapping + """ + return _get_default_qconfig_mapping(True, backend, version) + + +def _get_symmetric_qnnpack_qconfig_mapping() -> QConfigMapping: + """ + Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qconfig` + as the default QConfig. + """ + default_qconfig = default_symmetric_qnnpack_qconfig + return _get_default_qconfig_mapping_with_default_qconfig( + False, "qnnpack", default_qconfig + ) + + +def _get_symmetric_qnnpack_qat_qconfig_mapping() -> QConfigMapping: + """ + Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qat_qconfig` + as the default QConfig. + """ + default_qconfig = default_symmetric_qnnpack_qat_qconfig + return _get_default_qconfig_mapping_with_default_qconfig( + True, "qnnpack", default_qconfig + ) + + +def _get_default_qconfig_mapping_with_default_qconfig( + is_qat: bool, + backend: str, + default_qconfig: QConfig, +) -> QConfigMapping: + """ + Return a QConfigMapping that uses the provided qconfig as the default QConfig. + """ + if is_qat: + qconfig_mapping = get_default_qat_qconfig_mapping(backend) + else: + qconfig_mapping = get_default_qconfig_mapping(backend) + qconfig_mapping.set_global(default_qconfig) + for pattern in qconfig_mapping.object_type_qconfigs.keys(): + if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER: + qconfig_mapping.set_object_type(pattern, default_qconfig) + return qconfig_mapping + + +_QCONFIG_STYLE_ORDER: list[str] = [ + "global_qconfig", + "object_type_qconfigs", + "module_name_regex_qconfigs", + "module_name_qconfigs", + "module_name_object_type_order_qconfigs", +] + + +class QConfigMapping: + """ + Mapping from model ops to :class:`torch.ao.quantization.QConfig` s. + + The user can specify QConfigs using the following methods (in increasing match priority): + + ``set_global`` : sets the global (default) QConfig + + ``set_object_type`` : sets the QConfig for a given module type, function, or method name + + ``set_module_name_regex`` : sets the QConfig for modules matching the given regex string + + ``set_module_name`` : sets the QConfig for modules matching the given module name + + ``set_module_name_object_type_order`` : sets the QConfig for modules matching a combination + of the given module name, object type, and the index at which the module appears + + Example usage:: + + qconfig_mapping = QConfigMapping() + .set_global(global_qconfig) + .set_object_type(torch.nn.Linear, qconfig1) + .set_object_type(torch.nn.ReLU, qconfig1) + .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1) + .set_module_name_regex("foo.*", qconfig2) + .set_module_name("module1", qconfig1) + .set_module_name("module2", qconfig2) + .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig3) + + """ + + def __init__(self) -> None: + # In increasing match priority: + self.global_qconfig: QConfigAny = None + self.object_type_qconfigs: OrderedDict[Union[Callable, str], QConfigAny] = ( + OrderedDict() + ) + self.module_name_regex_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict() + self.module_name_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict() + self.module_name_object_type_order_qconfigs: OrderedDict[ + tuple[str, Callable, int], QConfigAny + ] = OrderedDict() + + def set_global(self, global_qconfig: QConfigAny) -> QConfigMapping: + """ + Set the global (default) QConfig. + """ + self.global_qconfig = global_qconfig + return self + + def set_object_type( + self, object_type: Union[Callable, str], qconfig: QConfigAny + ) -> QConfigMapping: + """ + Set the QConfig for a given module type, function, or method name. + If the QConfig for an existing object type was already set, the new QConfig will override the old one. + """ + self.object_type_qconfigs[object_type] = qconfig + return self + + def set_module_name_regex( + self, module_name_regex: str, qconfig: QConfigAny + ) -> QConfigMapping: + """ + Set the QConfig for modules matching the given regex string. + + Regexes will be matched in the order in which they are registered through this method. + Thus, the caller should register more specific patterns first, e.g.:: + + qconfig_mapping = QConfigMapping() + .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1) + .set_module_name_regex("foo.*bar.*", qconfig2) + .set_module_name_regex("foo.*", qconfig3) + + In this example, "foo.bar.conv0" would match qconfig1, "foo.bar.linear" would match qconfig2, + and "foo.baz.relu" would match qconfig3. + + If the QConfig for an existing module name regex was already set, the new QConfig will override the + old one while preserving the order in which the regexes were originally registered. + """ + self.module_name_regex_qconfigs[module_name_regex] = qconfig + return self + + def set_module_name(self, module_name: str, qconfig: QConfigAny) -> QConfigMapping: + """ + Set the QConfig for modules matching the given module name. + If the QConfig for an existing module name was already set, the new QConfig will override the old one. + """ + self.module_name_qconfigs[module_name] = qconfig + return self + + def set_module_name_object_type_order( + self, module_name: str, object_type: Callable, index: int, qconfig: QConfigAny + ) -> QConfigMapping: + """ + Set the QConfig for modules matching a combination of the given module name, object type, + and the index at which the module appears. + + If the QConfig for an existing (module name, object type, index) was already set, the new QConfig + will override the old one. + """ + self.module_name_object_type_order_qconfigs[ + (module_name, object_type, index) + ] = qconfig + return self + + def __repr__(self) -> str: + output = self.__class__.__name__ + " (" + for style_name in _QCONFIG_STYLE_ORDER: + output += f"\n {style_name}" + qconfigs = getattr(self, style_name) + if isinstance(qconfigs, OrderedDict) and len(qconfigs) > 0: + for key, qconfig in qconfigs.items(): + output += f"\n {key}: {qconfig}" + else: + output += f"\n {qconfigs}" + return output + "\n)" + + # TODO: remove this + def to_dict(self) -> dict[str, Any]: + """ + Convert this ``QConfigMapping`` to a dictionary with the following keys: + + "" (for global QConfig) + + "object_type" + + "module_name_regex" + + "module_name" + + "module_name_object_type_order" + + The values of this dictionary are lists of tuples. + """ + return { + _GLOBAL_DICT_KEY: self.global_qconfig, + _OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()), + _MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()), + _MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()), + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ + (*k, v) for k, v in self.module_name_object_type_order_qconfigs.items() + ], + } + + # TODO: remove this + @classmethod + def from_dict(cls, qconfig_dict: dict[str, Any]) -> QConfigMapping: + """ + Create a ``QConfigMapping`` from a dictionary with the following keys (all optional): + + "" (for global QConfig) + + "object_type" + + "module_name_regex" + + "module_name" + + "module_name_object_type_order" + + The values of this dictionary are expected to be lists of tuples. + """ + conf = cls() + if _GLOBAL_DICT_KEY in qconfig_dict: + conf.set_global(qconfig_dict[_GLOBAL_DICT_KEY]) + for object_type, qconfig in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []): + conf.set_object_type(object_type, qconfig) + for module_name_regex, qconfig in qconfig_dict.get( + _MODULE_NAME_REGEX_DICT_KEY, [] + ): + conf.set_module_name_regex(module_name_regex, qconfig) + for module_name, qconfig in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []): + conf.set_module_name(module_name, qconfig) + for module_name, object_type, index, qconfig in qconfig_dict.get( + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, [] + ): + conf.set_module_name_object_type_order( + module_name, object_type, index, qconfig + ) + return conf diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quant_type.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quant_type.py new file mode 100644 index 0000000000000000000000000000000000000000..18488d7f9ccba604ca8f1df7ea0ef4a88546d63e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quant_type.py @@ -0,0 +1,35 @@ +import enum + + +__all__ = [ + "QuantType", +] + + +# Quantization type (dynamic quantization, static quantization). +# Should match the c++ enum in quantization_type.h +class QuantType(enum.IntEnum): + DYNAMIC = 0 + STATIC = 1 + QAT = 2 + WEIGHT_ONLY = 3 + + +_quant_type_to_str = { + QuantType.STATIC: "static", + QuantType.DYNAMIC: "dynamic", + QuantType.QAT: "qat", + QuantType.WEIGHT_ONLY: "weight_only", +} + + +# TODO: make this private +def _get_quant_type_to_str(quant_type: QuantType) -> str: + return _quant_type_to_str[quant_type] + + +def _quant_type_from_str(name: str) -> QuantType: + for quant_type, s in _quant_type_to_str.items(): + if name == s: + return quant_type + raise ValueError(f"Unknown QuantType name '{name}'") diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantization_mappings.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantization_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..e22fba05bbc99ce10ea275bff7b6db1b005ad160 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantization_mappings.py @@ -0,0 +1,365 @@ +import copy +from typing import Any, Callable, Optional, Union + +import torch +import torch.ao.nn as ao_nn +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.intrinsic.quantized.dynamic as nniqd +import torch.ao.nn.qat as nnqat +import torch.ao.nn.qat.dynamic as nnqatd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.ao.nn.quantized.reference as nnqr + +# Because `torch.ao.nn` uses lazy imports, we need to make +# sure we import the contents explicitly here. +import torch.ao.nn.sparse +import torch.nn.functional as F +from torch import nn +from torch.ao.quantization.fake_quantize import ( + default_fixed_qparams_range_0to1_fake_quant, + default_fixed_qparams_range_neg1to1_fake_quant, +) +from torch.ao.quantization.stubs import DeQuantStub, QuantStub +from torch.ao.quantization.utils import get_combined_dict +from torch.nn.utils.parametrize import type_before_parametrizations + + +__all__ = [ + "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS", + "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS", + "DEFAULT_QAT_MODULE_MAPPINGS", + "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS", + "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS", + "DEFAULT_MODULE_TO_ACT_POST_PROCESS", + "DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS", + "DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS", + "no_observer_set", + "get_default_static_quant_module_mappings", + "get_default_static_quant_reference_module_mappings", + "get_embedding_static_quant_module_mappings", + "get_default_static_sparse_quant_module_mappings", + "get_static_quant_module_class", + "get_dynamic_quant_module_class", + "get_default_qat_module_mappings", + "get_embedding_qat_module_mappings", + "get_default_dynamic_quant_module_mappings", + "get_default_dynamic_sparse_quant_module_mappings", + "get_default_qconfig_propagation_list", + "get_default_compare_output_module_list", + "get_default_float_to_quantized_operator_mappings", + "get_quantized_operator", +] + +# Default map for swapping float module to reference quantized modules +DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + QuantStub: nnq.Quantize, + DeQuantStub: nnq.DeQuantize, + nn.Linear: nnqr.Linear, + nn.Conv1d: nnqr.Conv1d, + nn.Conv2d: nnqr.Conv2d, + nn.Conv3d: nnqr.Conv3d, + nn.ConvTranspose1d: nnqr.ConvTranspose1d, + nn.ConvTranspose2d: nnqr.ConvTranspose2d, + nn.ConvTranspose3d: nnqr.ConvTranspose3d, + nn.Embedding: nnqr.Embedding, + nn.EmbeddingBag: nnqr.EmbeddingBag, + nn.GRUCell: nnqr.GRUCell, + nn.LSTMCell: nnqr.LSTMCell, + nn.RNNCell: nnqr.RNNCell, + nn.LSTM: nnqr.LSTM, +} + +# Default map for swapping float module to quantized ones +DEFAULT_STATIC_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + QuantStub: nnq.Quantize, + DeQuantStub: nnq.DeQuantize, + nn.BatchNorm2d: nnq.BatchNorm2d, + nn.BatchNorm3d: nnq.BatchNorm3d, + nn.Dropout: nnq.Dropout, + nn.Conv1d: nnq.Conv1d, + nn.Conv2d: nnq.Conv2d, + nn.Conv3d: nnq.Conv3d, + nn.ConvTranspose1d: nnq.ConvTranspose1d, + nn.ConvTranspose2d: nnq.ConvTranspose2d, + nn.ConvTranspose3d: nnq.ConvTranspose3d, + nn.ELU: nnq.ELU, + nn.Embedding: nnq.Embedding, + nn.EmbeddingBag: nnq.EmbeddingBag, + nn.GroupNorm: nnq.GroupNorm, + nn.Hardswish: nnq.Hardswish, + nn.InstanceNorm1d: nnq.InstanceNorm1d, + nn.InstanceNorm2d: nnq.InstanceNorm2d, + nn.InstanceNorm3d: nnq.InstanceNorm3d, + nn.LayerNorm: nnq.LayerNorm, + nn.LeakyReLU: nnq.LeakyReLU, + nn.modules.linear.NonDynamicallyQuantizableLinear: nnq.Linear, + nn.Linear: nnq.Linear, + nn.ReLU6: nnq.ReLU6, + nn.PReLU: nnq.PReLU, + # Wrapper Modules: + nnq.FloatFunctional: nnq.QFunctional, + # Intrinsic modules: + nni.BNReLU2d: nniq.BNReLU2d, + nni.BNReLU3d: nniq.BNReLU3d, + nni.ConvReLU1d: nniq.ConvReLU1d, + nni.ConvReLU2d: nniq.ConvReLU2d, + nni.ConvReLU3d: nniq.ConvReLU3d, + nni.ConvAdd2d: nniq.ConvAdd2d, + nni.ConvAddReLU2d: nniq.ConvAddReLU2d, + nni.LinearReLU: nniq.LinearReLU, + nni.LinearLeakyReLU: nniq.LinearLeakyReLU, + nni.LinearTanh: nniq.LinearTanh, + nniqat.ConvBn1d: nnq.Conv1d, + nniqat.ConvBn2d: nnq.Conv2d, + nniqat.ConvBn3d: nnq.Conv3d, + nniqat.ConvBnReLU1d: nniq.ConvReLU1d, + nniqat.ConvBnReLU2d: nniq.ConvReLU2d, + nniqat.ConvBnReLU3d: nniq.ConvReLU3d, + nniqat.ConvReLU2d: nniq.ConvReLU2d, + nniqat.ConvReLU3d: nniq.ConvReLU3d, + nniqat.LinearReLU: nniq.LinearReLU, + nniqat.LinearBn1d: nnq.Linear, + # QAT modules: + nnqat.Linear: nnq.Linear, + nnqat.Conv2d: nnq.Conv2d, + nnqat.Conv3d: nnq.Conv3d, +} + +# Default map for swapping float module to qat modules +DEFAULT_QAT_MODULE_MAPPINGS: dict[Callable, Any] = { + nn.Conv2d: nnqat.Conv2d, + nn.Conv3d: nnqat.Conv3d, + nn.Linear: nnqat.Linear, + nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear, + # Intrinsic modules: + nni.ConvBn1d: nniqat.ConvBn1d, + nni.ConvBn2d: nniqat.ConvBn2d, + nni.ConvBn3d: nniqat.ConvBn3d, + nni.ConvBnReLU1d: nniqat.ConvBnReLU1d, + nni.ConvBnReLU2d: nniqat.ConvBnReLU2d, + nni.ConvBnReLU3d: nniqat.ConvBnReLU3d, + nni.ConvReLU2d: nniqat.ConvReLU2d, + nni.ConvReLU3d: nniqat.ConvReLU3d, + nni.LinearReLU: nniqat.LinearReLU, + nni.LinearBn1d: nniqat.LinearBn1d, +} + +# Default map for swapping dynamic modules +DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + nn.GRUCell: nnqd.GRUCell, + nn.Linear: nnqd.Linear, + nnqatd.Linear: nnqd.Linear, + nn.modules.linear.NonDynamicallyQuantizableLinear: nnqd.Linear, + nn.LSTM: nnqd.LSTM, + nn.GRU: nnqd.GRU, + nn.LSTMCell: nnqd.LSTMCell, + nn.RNNCell: nnqd.RNNCell, + nni.LinearReLU: nniqd.LinearReLU, + nn.EmbeddingBag: nnq.EmbeddingBag, + nn.Embedding: nnq.Embedding, + # Don't want to enable these by default because the numerical + # accuracy is poor compared to other dynamic ops + # nn.Conv1d: nnqd.Conv1d, + # nn.Conv2d: nnqd.Conv2d, + # nn.Conv3d: nnqd.Conv3d, + # nn.ConvTranspose1d: nnqd.ConvTranspose1d, + # nn.ConvTranspose2d: nnqd.ConvTranspose2d, + # nn.ConvTranspose3d: nnqd.ConvTranspose3d, +} + +# Allowlist for propagating the qconfig +_INCLUDE_QCONFIG_PROPAGATE_LIST: set[Callable] = { + nn.Sequential, +} + +# Default mapping from floating point function or torch ops to quantized ops +# TODO: merge with default static mapping +DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS: dict[Union[Callable, str], Callable] = { + F.elu: torch.ops.quantized.elu, + F.hardswish: torch.ops.quantized.hardswish, + F.instance_norm: torch.ops.quantized.instance_norm, + F.layer_norm: torch.ops.quantized.layer_norm, + F.leaky_relu: torch.ops.quantized.leaky_relu, + F.dropout: torch.ops.quantized.dropout, +} + +# mapping from module to output activation post process class +DEFAULT_MODULE_TO_ACT_POST_PROCESS: dict[Callable, Callable] = { + nn.Hardsigmoid: default_fixed_qparams_range_0to1_fake_quant, + nn.Sigmoid: default_fixed_qparams_range_0to1_fake_quant, + nn.Softmax: default_fixed_qparams_range_0to1_fake_quant, + nn.Tanh: default_fixed_qparams_range_neg1to1_fake_quant, +} + +# Default map for swapping float module to static sparse quantized ones +DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + nn.Linear: ao_nn.sparse.quantized.Linear +} + +# Default map for swapping float module to dynamic sparse quantized ones +DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS: dict[Callable, Any] = { + nn.Linear: ao_nn.sparse.quantized.dynamic.Linear +} + + +def no_observer_set() -> set[Any]: + r"""These modules cannot have observers inserted by default.""" + no_observers = {nn.quantizable.LSTM, nn.quantizable.MultiheadAttention} + return no_observers + + +def get_default_static_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping for post training static quantization""" + return copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS) + + +def get_default_static_quant_reference_module_mappings() -> dict[Callable, Any]: + """Get reference module mapping for post training static quantization""" + return copy.deepcopy(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS) + + +def get_embedding_static_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping, including mapping for embedding QAT""" + mapping = copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS) + mapping[nnqat.EmbeddingBag] = nnq.EmbeddingBag + mapping[nnqat.Embedding] = nnq.Embedding + return mapping + + +def get_default_static_sparse_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping for post training static sparse quantization""" + return copy.deepcopy(DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS) + + +def get_static_quant_module_class( + float_module_class: Callable, + additional_static_quant_mapping: Optional[dict[Callable, Any]] = None, + is_reference: bool = False, +) -> Any: + r"""n Get the statically quantized module class corresponding to + the floating point module class + """ + if additional_static_quant_mapping is None: + additional_static_quant_mapping = {} + all_mappings = get_combined_dict( + DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS + if is_reference + else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, + additional_static_quant_mapping, + ) + static_quant_module_class = all_mappings.get(float_module_class, None) + assert static_quant_module_class is not None, ( + f"Floating point module class {str(float_module_class)}" + + " does not have a corresponding quantized module class" + ) + return copy.deepcopy(static_quant_module_class) + + +def get_dynamic_quant_module_class( + float_module_class: Callable, + additional_dynamic_quant_mapping: Optional[dict[Callable, Any]] = None, +) -> Any: + r"""n Get the dynamically quantized module class corresponding to + the floating point module class + """ + if additional_dynamic_quant_mapping is None: + additional_dynamic_quant_mapping = {} + all_mappings = get_combined_dict( + DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping + ) + dynamic_quant_module_class = all_mappings.get(float_module_class, None) + assert dynamic_quant_module_class is not None, ( + f"Floating point module class {str(float_module_class)}" + + " does not have a corresponding quantized module class" + ) + return copy.deepcopy(dynamic_quant_module_class) + + +def get_default_qat_module_mappings() -> dict[Callable, Any]: + """Get default module mapping for quantization aware training""" + return copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS) + + +def get_embedding_qat_module_mappings() -> dict[Callable, Any]: + """Get module mapping for quantization aware training + This is includes default values in addition to + enabling qat for embeddings. + """ + mapping = copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS) + mapping[nn.EmbeddingBag] = nnqat.EmbeddingBag + mapping[nn.Embedding] = nnqat.Embedding + return mapping + + +def get_default_dynamic_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping for post training dynamic quantization""" + return DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS + + +def get_default_dynamic_sparse_quant_module_mappings() -> dict[Callable, Any]: + """Get module mapping for post training dynamic sparse quantization""" + return DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS + + +def get_default_qconfig_propagation_list() -> set[Callable]: + """Get the default list of module types that we'll attach qconfig + attribute to in prepare + """ + QCONFIG_PROPAGATE_MODULE_CLASS_LIST = ( + set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) + | _INCLUDE_QCONFIG_PROPAGATE_LIST + ) + return copy.deepcopy(QCONFIG_PROPAGATE_MODULE_CLASS_LIST) + + +def get_default_compare_output_module_list() -> set[Callable]: + """Get list of module class types that we will record output + in numeric suite + """ + NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = ( + set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.values()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.values()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values()) + | set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) + | _INCLUDE_QCONFIG_PROPAGATE_LIST + ) + return copy.deepcopy(NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST) + + +def get_default_float_to_quantized_operator_mappings() -> dict[ + Union[Callable, str], Callable +]: + return copy.deepcopy(DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS) + + +# TODO: merge with get_static_quant_module_class +def get_quantized_operator(float_op: Union[Callable, str]) -> Callable: + """Get the quantized operator corresponding to the float operator""" + quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None) + assert quantized_op is not None, ( + f"Operator {str(float_op)} does not have corresponding quantized op" + ) + return quantized_op + + +def _get_special_act_post_process(module: torch.nn.Module) -> Optional[Callable]: + r"""Get the special activation post process for `module`, this has + higher priority than the activation post process in `qconfig` + e.g. + input: torch.nn.Sigmoid + output: default_affine_fixed_qparam_fake_quant + """ + return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get( + type_before_parametrizations(module), None + ) + + +def _has_special_act_post_process(module: torch.nn.Module) -> bool: + return module.training and type(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantize.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..b85618a16331fe2752be746316a9a35c90ee3266 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantize.py @@ -0,0 +1,819 @@ +# mypy: allow-untyped-defs +import copy +import inspect +import itertools +import typing_extensions +import warnings + +import torch +import torch.ao.nn.quantized as nnq +import torch.nn as nn +from torch.ao.nn.intrinsic import _FusedModule +from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.qconfig import ( + _activation_is_memoryless, + _add_module_to_qconfig_obs_ctr, + default_dynamic_qconfig, + float16_dynamic_qconfig, + float_qparams_weight_only_qconfig, + float_qparams_weight_only_qconfig_4bit, +) +from torch.ao.quantization.quantization_mappings import ( + _get_special_act_post_process, + _has_special_act_post_process, + get_default_dynamic_quant_module_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, + get_default_static_quant_module_mappings, + get_default_static_quant_reference_module_mappings, + no_observer_set, +) +from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper +from torch.nn.utils.parametrize import type_before_parametrizations + +from .utils import ( + DEPRECATION_WARNING, + get_qparam_dict, + has_no_children_ignoring_parametrizations, +) + + +__all__ = [ + "get_default_custom_config_dict", + "propagate_qconfig_", + "add_quant_dequant", + "prepare", + "quantize", + "quantize_dynamic", + "prepare_qat", + "quantize_qat", + "convert", + "swap_module", +] + + +# TODO remove this once BC is no longer required to avoid a SEV +is_activation_post_process = _is_activation_post_process + + +_DEFAULT_CUSTOM_CONFIG_DICT = { + "float_to_observed_custom_module_class": { + nn.LSTM: nn.quantizable.LSTM, + nn.MultiheadAttention: nn.quantizable.MultiheadAttention, + }, + "observed_to_quantized_custom_module_class": { + nn.quantizable.LSTM: nn.quantized.LSTM, + nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention, + }, +} + + +def get_default_custom_config_dict(): + r"""Defines the default custom config dict.""" + return _DEFAULT_CUSTOM_CONFIG_DICT + + +def _propagate_qconfig_helper( + module, + qconfig_dict, + qconfig_parent=None, + prefix="", + prepare_custom_config_dict=None, +): + r"""This is a helper function for `propagate_qconfig_` + + Args: + module: input module + qconfig_dict: dictionary that maps from name of submodule to quantization + configuration + qconfig_parent: quantization config of parent module, we will fallback to + this config when there is no specified config for current + module + prefix: corresponding prefix of the current module, used as key in + qconfig_dict + prepare_custom_config_dict: dictionary for custom handling of modules + see docs for :func:`~torch.ao.quantization.prepare_fx` + + Return: + None, module is modified inplace with qconfig attached + """ + + module_qconfig = qconfig_dict.get( + type_before_parametrizations(module), qconfig_parent + ) + module_qconfig = qconfig_dict.get(prefix, module_qconfig) + module_qconfig = getattr(module, "qconfig", module_qconfig) + + torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module) + + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module) + module.qconfig = qconfig_with_device_check + + for name, child in module.named_children(): + module_prefix = prefix + "." + name if prefix else name + # do no not propagate qconfig to child if child is non traceable + if prepare_custom_config_dict is None or not ( + name in prepare_custom_config_dict.get("non_traceable_module_name", []) + or type(child) + in prepare_custom_config_dict.get("non_traceable_module_class", []) + ): + _propagate_qconfig_helper( + child, qconfig_dict, qconfig_with_device_check, module_prefix + ) + + +def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None): + r"""Propagate qconfig through the module hierarchy and assign `qconfig` + attribute on each leaf module + + Args: + module: input module + qconfig_dict: dictionary that maps from name or type of submodule to + quantization configuration, qconfig applies to all submodules of a + given module unless qconfig for the submodules are specified (when + the submodule already has qconfig attribute) + prepare_custom_config_dict: dictionary for custom handling of modules + see docs for :func:`~torch.ao.quantization.prepare_fx` + + Return: + None, module is modified inplace with qconfig attached + """ + if qconfig_dict is None: + qconfig_dict = {} + if prepare_custom_config_dict is None: + prepare_custom_config_dict = {} + _propagate_qconfig_helper( + module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict + ) + + +def _observer_forward_hook(self, input, output): + r"""Forward hook that calls observer on the output""" + return self.activation_post_process(output) + + +def _observer_forward_pre_hook(self, input): + r"""Forward pre hook that calls observer on the output""" + return self.activation_post_process(input[0]) + + +def _register_activation_post_process_hook(module, pre_hook=False): + assert hasattr(module, "activation_post_process"), ( + "Expect activation_post_process attribute already attached to the module" + ) + if pre_hook: + module.register_forward_pre_hook(_observer_forward_pre_hook, prepend=True) + else: + module.register_forward_hook(_observer_forward_hook, prepend=True) + + +def _add_observer_( + module, + qconfig_propagation_list=None, + non_leaf_module_list=None, + device=None, + custom_module_class_mapping=None, +): + r"""Add observer for the leaf child of the module. + + This function insert observer module to all leaf child module that + has a valid qconfig attribute. + + Args: + module: input module with qconfig attributes for all the leaf modules that we want to quantize + qconfig_propagation_list: a list of quantizable modules that will have observers added to them + if they are leaf nodes + device: parent device, if any + non_leaf_module_list: list of non-leaf modules we want to add observer + + Return: + None, module is modified inplace with added observer modules and forward_hooks + """ + if qconfig_propagation_list is None: + qconfig_propagation_list = get_default_qconfig_propagation_list() + + if custom_module_class_mapping is None: + custom_module_class_mapping = {} + + # respect device affinity when adding observers + if device is None: + devices = _get_unique_devices_(module) + assert len(devices) <= 1, ( + f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + + def get_activation_post_process(qconfig, device, special_act_post_process=None): + activation = ( + qconfig.activation() + if special_act_post_process is None + else special_act_post_process() + ) + if device is not None: + activation.to(device) + return activation + + def needs_observation(m): + return hasattr(m, "qconfig") and m.qconfig is not None + + def insert_activation_post_process(m, special_act_post_process=None): + """Adds an activation post process module and register + a pre or post hook that calls the module + """ + # We don't insert observer/fake_quantize for DeQuantStub + if needs_observation(m) and not isinstance(m, DeQuantStub): + # observer and hook will be gone after we swap the module + m.add_module( + "activation_post_process", + get_activation_post_process( + m.qconfig, device, special_act_post_process + ), + ) + # Register observer as the first entry in the hook list + # All post forward hooks are preserved and will be executed after the observer before convert + _register_activation_post_process_hook( + m, pre_hook=_activation_is_memoryless(m.qconfig) + ) + + for name, child in module.named_children(): + # TODO remove Dropout special after codebase stable + if type_before_parametrizations(child) in [nn.Dropout]: + continue + elif issubclass( + type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional) + ): + if needs_observation(child): + assert hasattr(child, "activation_post_process"), ( + f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`" + ) + child.activation_post_process = get_activation_post_process( + child.qconfig, device + ) + elif isinstance(child, _FusedModule): + # activation_post_process are now added directly to nn.Sequential/_FusedModule + if needs_observation(child): + insert_activation_post_process(child) + elif ( + non_leaf_module_list is not None + and type_before_parametrizations(child) in non_leaf_module_list + ): + if needs_observation(child): + insert_activation_post_process(child) + elif _has_special_act_post_process(child): + special_act_post_process = _get_special_act_post_process(child) + insert_activation_post_process(child, special_act_post_process) + elif ( + needs_observation(child) + and type_before_parametrizations(child) in custom_module_class_mapping + ): + observed_class = custom_module_class_mapping[ + type_before_parametrizations(child) + ] + observed_child = observed_class.from_float(child) + setattr(module, name, observed_child) + # TODO: These are the modules that cannot be observed + # Once there are more, we should move them to a separate list + if not issubclass(observed_class, tuple(no_observer_set())): + insert_activation_post_process(observed_child) + else: + _add_observer_( + child, + qconfig_propagation_list, + non_leaf_module_list, + device, + custom_module_class_mapping, + ) + + # Insert observers only for leaf nodes, note that this observer is for + # the output of the module, for input QuantStub will observe them + if ( + has_no_children_ignoring_parametrizations(module) + and not isinstance(module, torch.nn.Sequential) + and type_before_parametrizations(module) in qconfig_propagation_list + ): + insert_activation_post_process(module) + # This is a special case for AdaRound eager mode + # AdaRound contains weight_fake_quant to be propagated from API to convert + # leaf node check with a number of children looks naive assumption that blocks + # Adding an exception case for AdaRound + if ( + hasattr(module, "weight_fake_quant") + and not isinstance(module, torch.nn.Sequential) + and type_before_parametrizations(module) in qconfig_propagation_list + ): + insert_activation_post_process(module) + + +def _get_unique_devices_(module): + return {p.device for p in module.parameters() if p.device.type != "meta"} | { + p.device for p in module.buffers() if p.device.type != "meta" + } + + +def add_quant_dequant(module): + r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig + Note that this function will modify the children of module inplace and it + can return a new module which wraps the input module as well. + + Args: + module: input module with qconfig attributes for all the leaf modules + that we want to quantize + + Return: + Either the inplace modified module with submodules wrapped in + `QuantWrapper` based on qconfig or a new `QuantWrapper` module which + wraps the input module, the latter case only happens when the input + module is a leaf module and we want to quantize it. + """ + if ( + has_no_children_ignoring_parametrizations(module) + and hasattr(module, "qconfig") + and module.qconfig + ): + return QuantWrapper(module) + + for name, child in module.named_children(): + module._modules[name] = add_quant_dequant(child) + return module + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare( + model, + inplace=False, + allow_list=None, + observer_non_leaf_module_list=None, + prepare_custom_config_dict=None, +): + r"""Prepares a copy of the model for quantization calibration or quantization-aware training. + + Quantization configuration should be assigned preemptively + to individual submodules in `.qconfig` attribute. + + The model will be attached with observer or fake quant modules, and qconfig + will be propagated. + + Args: + `model`: input model to be modified in-place + `inplace`: carry out model transformations in-place, the original module is mutated + `allow_list`: list of quantizable modules + `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer + `prepare_custom_config_dict`: customization configuration dictionary for prepare function + + .. code-block:: python + + # Example of prepare_custom_config_dict: + prepare_custom_config_dict = { + # user will manually define the corresponding observed + # module class which has a from_float class method that converts + # float custom module to observed custom module + "float_to_observed_custom_module_class": {CustomModule: ObservedCustomModule} + } + + """ + torch._C._log_api_usage_once("quantization_api.quantize.prepare") + if prepare_custom_config_dict is None: + prepare_custom_config_dict = get_default_custom_config_dict() + custom_module_class_mapping = prepare_custom_config_dict.get( + "float_to_observed_custom_module_class", {} + ) + + if not inplace: + model = copy.deepcopy(model) + + # TODO: remove allow_list + qconfig_propagation_list = allow_list + if allow_list is None: + qconfig_propagation_list = get_default_qconfig_propagation_list() + propagate_qconfig_(model, qconfig_dict=None) + + # sanity check common API misusage + if not any(hasattr(m, "qconfig") and m.qconfig for m in model.modules()): + warnings.warn( + "None of the submodule got qconfig applied. Make sure you " + "passed correct configuration through `qconfig_dict` or " + "by assigning the `.qconfig` attribute directly on submodules" + ) + + _add_observer_( + model, + qconfig_propagation_list, + observer_non_leaf_module_list, + custom_module_class_mapping=custom_module_class_mapping, + ) + return model + + +def _remove_activation_post_process(module): + # TODO: maybe we should change activation_post_process to _activation_post_process + # to prevent it from being used by user + if hasattr(module, "activation_post_process") and _is_activation_post_process( + module.activation_post_process + ): + delattr(module, "activation_post_process") + + # remove activation_post_process pre and post hooks + def remove_hooks(pre_hook=False): + hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks + observer_hook = ( + _observer_forward_pre_hook if pre_hook else _observer_forward_hook + ) + handle_ids_to_remove = set() + for handle_id, hook_fn in hook_map.items(): + if hook_fn is observer_hook: + handle_ids_to_remove.add(handle_id) + for handle_id in handle_ids_to_remove: + hook_map.pop(handle_id) + + remove_hooks(pre_hook=True) + remove_hooks(pre_hook=False) + + +# TODO: rename to something more general +def _remove_qconfig(module): + r"""Clean up the qconfig left in the module so that new qconfig can be + propagated. + + Args: + module: module to be cleaned up + """ + for child in module.children(): + _remove_qconfig(child) + + if hasattr(module, "qconfig"): + del module.qconfig + + _remove_activation_post_process(module) + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def quantize(model, run_fn, run_args, mapping=None, inplace=False): + r"""Quantize the input float model with post training static quantization. + + First it will prepare the model for calibration, then it calls + `run_fn` which will run the calibration step, after that we will + convert the model to a quantized model. + + Args: + model: input float model + run_fn: a calibration function for calibrating the prepared model + run_args: positional arguments for `run_fn` + inplace: carry out model transformations in-place, the original module is mutated + mapping: correspondence between original module types and quantized counterparts + + Return: + Quantized model. + """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize") + if mapping is None: + mapping = get_default_static_quant_module_mappings() + if not inplace: + model = copy.deepcopy(model) + model.eval() + prepare(model, inplace=True) + run_fn(model, *run_args) + convert(model, mapping, inplace=True) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def quantize_dynamic( + model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False +): + r"""Converts a float model to dynamic (i.e. weights-only) quantized model. + + Replaces specified modules with dynamic weight-only quantized versions and output the quantized model. + + For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization + by default is performed for layers with large weights size - i.e. Linear and RNN variants. + + Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`. + If `qconfig` is provided, the `dtype` argument is ignored. + + Args: + model: input model + qconfig_spec: Either: + + - A dictionary that maps from name or type of submodule to quantization + configuration, qconfig applies to all submodules of a given + module unless qconfig for the submodules are specified (when the + submodule already has qconfig attribute). Entries in the dictionary + need to be QConfig instances. + + - A set of types and/or submodule names to apply dynamic quantization to, + in which case the `dtype` argument is used to specify the bit-width + + inplace: carry out model transformations in-place, the original module is mutated + mapping: maps type of a submodule to a type of corresponding dynamically quantized version + with which the submodule needs to be replaced + + """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic") + if qconfig_spec is None: + if dtype == torch.qint8: + qconfig_spec = { + nn.Linear: default_dynamic_qconfig, + nn.LSTM: default_dynamic_qconfig, + nn.GRU: default_dynamic_qconfig, + nn.LSTMCell: default_dynamic_qconfig, + nn.RNNCell: default_dynamic_qconfig, + nn.GRUCell: default_dynamic_qconfig, + } + elif dtype == torch.float16: + qconfig_spec = { + nn.Linear: float16_dynamic_qconfig, + nn.LSTM: float16_dynamic_qconfig, + nn.GRU: float16_dynamic_qconfig, + nn.LSTMCell: float16_dynamic_qconfig, + nn.RNNCell: float16_dynamic_qconfig, + nn.GRUCell: float16_dynamic_qconfig, + } + elif dtype == torch.quint8: + qconfig_spec = { + nn.EmbeddingBag: float_qparams_weight_only_qconfig, + nn.Embedding: float_qparams_weight_only_qconfig, + } + elif dtype == torch.quint4x2: + qconfig_spec = { + nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit, + } + else: + raise ValueError( + f"Don't know how to quantize with default settings for {dtype}. Provide full qconfig please" + ) + elif isinstance(qconfig_spec, set): + if dtype is torch.qint8: + default_qconfig = default_dynamic_qconfig + elif dtype is torch.float16: + default_qconfig = float16_dynamic_qconfig + elif dtype is torch.quint8: + default_qconfig = float_qparams_weight_only_qconfig + elif dtype is torch.quint4x2: + default_qconfig = float_qparams_weight_only_qconfig_4bit + else: + raise RuntimeError( + "Unknown dtype specified for quantize_dynamic: ", str(dtype) + ) + qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig))) + + if mapping is None: + mapping = get_default_dynamic_quant_module_mappings() + + if not inplace: + model = copy.deepcopy(model) + model.eval() + propagate_qconfig_(model, qconfig_spec) + convert(model, mapping, inplace=True) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_qat(model, mapping=None, inplace=False): + r""" + Prepares a copy of the model for quantization calibration or + quantization-aware training and converts it to quantized version. + + Quantization configuration should be assigned preemptively + to individual submodules in `.qconfig` attribute. + + Args: + model: input model to be modified in-place + mapping: dictionary that maps float modules to quantized modules to be + replaced. + inplace: carry out model transformations in-place, the original module + is mutated + """ + torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat") + assert model.training, "prepare_qat only works on models in training mode" + if mapping is None: + mapping = get_default_qat_module_mappings() + + if not inplace: + model = copy.deepcopy(model) + + propagate_qconfig_(model, qconfig_dict=None) + convert(model, mapping=mapping, inplace=True, remove_qconfig=False) + prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def quantize_qat(model, run_fn, run_args, inplace=False): + r"""Do quantization aware training and output a quantized model + + Args: + model: input model + run_fn: a function for evaluating the prepared model, can be a + function that simply runs the prepared model or a training + loop + run_args: positional arguments for `run_fn` + + Return: + Quantized model. + """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat") + if not inplace: + model = copy.deepcopy(model) + model.train() + prepare_qat(model, inplace=True) + run_fn(model, *run_args) + convert(model, inplace=True) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def convert( + module, + mapping=None, + inplace=False, + remove_qconfig=True, + is_reference=False, + convert_custom_config_dict=None, + use_precomputed_fake_quant=False, +): + r"""Converts submodules in input module to a different module according to `mapping` + by calling `from_float` method on the target module class. And remove qconfig at the + end if remove_qconfig is set to True. + + Args: + `module`: prepared and calibrated module + `mapping`: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + `inplace`: carry out model transformations in-place, the original module + is mutated + `convert_custom_config_dict`: custom configuration dictionary for convert function + `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant + + .. code-block:: python + + # Example of convert_custom_config_dict: + convert_custom_config_dict = { + # user will manually define the corresponding quantized + # module class which has a from_observed class method that converts + # observed custom module to quantized custom module + "observed_to_quantized_custom_module_class": { + ObservedCustomModule: QuantizedCustomModule + } + } + + """ + torch._C._log_api_usage_once("quantization_api.quantize.convert") + if not inplace: + module = copy.deepcopy(module) + _convert( + module, + mapping, + inplace=True, + is_reference=is_reference, + convert_custom_config_dict=convert_custom_config_dict, + use_precomputed_fake_quant=use_precomputed_fake_quant, + ) + if remove_qconfig: + _remove_qconfig(module) + return module + + +def _convert( + module, + mapping=None, + inplace=False, + is_reference=False, + convert_custom_config_dict=None, + use_precomputed_fake_quant=False, +): + r"""Converts submodules in input module to a different module according to `mapping` + by calling `from_float` method on the target module class + + Args: + module: input module + mapping: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + inplace: carry out model transformations in-place, the original module + is mutated + is_reference: a flag to enable quantized reference module + use_precomputed_fake_quant: a flag to enable use of precomputed fake quant + + """ + if mapping is None: + mapping = ( + get_default_static_quant_reference_module_mappings() + if is_reference + else get_default_static_quant_module_mappings() + ) + if convert_custom_config_dict is None: + convert_custom_config_dict = get_default_custom_config_dict() + custom_module_class_mapping = convert_custom_config_dict.get( + "observed_to_quantized_custom_module_class", {} + ) + + if not inplace: + module = copy.deepcopy(module) + reassign = {} + for name, mod in module.named_children(): + # both fused modules and observed custom modules are + # swapped as one unit + if ( + not isinstance(mod, _FusedModule) + and type_before_parametrizations(mod) not in custom_module_class_mapping + ): + _convert( + mod, + mapping, + True, # inplace + is_reference, + convert_custom_config_dict, + use_precomputed_fake_quant=use_precomputed_fake_quant, + ) + reassign[name] = swap_module( + mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant + ) + + for key, value in reassign.items(): + module._modules[key] = value + + return module + + +def swap_module( + mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False +): + r"""Swaps the module if it has a quantized counterpart and it has an + `observer` attached. + + Args: + mod: input module + mapping: a dictionary that maps from nn module to nnq module + + Return: + The corresponding quantized module of `mod` + """ + new_mod = mod + if hasattr(mod, "qconfig") and mod.qconfig is not None: + swapped = False + if type_before_parametrizations(mod) in custom_module_class_mapping: + new_mod = custom_module_class_mapping[ + type_before_parametrizations(mod) + ].from_observed(mod) + swapped = True + elif type_before_parametrizations(mod) in mapping: + qmod = mapping[type_before_parametrizations(mod)] + if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE: + assert mod.qconfig is not None + weight_post_process = mod.qconfig.weight() + weight_post_process(mod.weight) + weight_qparams = get_qparam_dict(weight_post_process) + new_mod = qmod.from_float(mod, weight_qparams) + else: + sig = inspect.signature(qmod.from_float) + if "use_precomputed_fake_quant" in sig.parameters: + new_mod = qmod.from_float( + mod, use_precomputed_fake_quant=use_precomputed_fake_quant + ) + else: + new_mod = qmod.from_float(mod) + swapped = True + + if swapped: + # Preserve module's pre forward hooks. They'll be called on quantized input + for pre_hook_fn in mod._forward_pre_hooks.values(): + new_mod.register_forward_pre_hook(pre_hook_fn) + # Preserve module's post forward hooks except _observer_forward_hook + # After convert they'll work with quantized output + for hook_fn in mod._forward_hooks.values(): + if hook_fn is not _observer_forward_hook: + new_mod.register_forward_hook(hook_fn) + + # respect device affinity when swapping modules + devices = _get_unique_devices_(mod) + assert len(devices) <= 1 or ( + len(devices) == 2 and torch.device("meta") in devices + ), ( + f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + if device: + new_mod.to(device) + return new_mod + + +def _get_observer_dict(mod, target_dict, prefix=""): + r"""Traverse the modules and save all observers into dict. + This is mainly used for quantization accuracy debug + Args: + mod: the top module we want to save all observers + prefix: the prefix for the current module + target_dict: the dictionary used to save all the observers + """ + + def get_prefix(prefix): + return prefix if prefix == "" else prefix + "." + + if hasattr(mod, "activation_post_process"): + target_dict[get_prefix(prefix) + "activation_post_process"] = ( + mod.activation_post_process + ) + for name, child in mod.named_children(): + module_prefix = get_prefix(prefix) + name if prefix else name + _get_observer_dict(child, target_dict, module_prefix) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantize_fx.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantize_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..ce08882b8ddf295e7bf5bd79dc0b5b27322cf16c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantize_fx.py @@ -0,0 +1,759 @@ +import copy +import typing_extensions +import warnings +from typing import Any, Optional, Union + +import torch +from torch.fx import GraphModule +from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY + +from .backend_config import BackendConfig, get_tensorrt_backend_config # noqa: F401 +from .fx.convert import convert +from .fx.custom_config import ConvertCustomConfig, FuseCustomConfig, PrepareCustomConfig +from .fx.fuse import fuse # noqa: F401 +from .fx.graph_module import ObservedGraphModule # noqa: F401 +from .fx.prepare import prepare # noqa: F401 +from .fx.tracer import QuantizationTracer, Scope, ScopeContextManager # noqa: F401 +from .fx.utils import ( # noqa: F401 + get_custom_module_class_keys, + get_skipped_module_name_and_classes, +) +from .qconfig_mapping import QConfigMapping +from .utils import DEPRECATION_WARNING + + +def attach_preserved_attrs_to_model( + model: Union[GraphModule, torch.nn.Module], + preserved_attrs: dict[str, Any], +) -> None: + """Store preserved attributes to the model.meta so that it can be preserved during deepcopy""" + model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs) # type: ignore[operator, index, assignment] + # set the preserved attributes in the model so that user can call + # model.attr as they do before calling fx graph mode quantization + for attr_name, attr in model.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): # type: ignore[index, union-attr] + setattr(model, attr_name, attr) + + +def _check_is_graph_module(model: torch.nn.Module) -> None: + if not isinstance(model, GraphModule): + raise ValueError( + "input model must be a GraphModule, " + + "Got type:" + + str(type(model)) + + " Please make " + + "sure to follow the tutorials." + ) + + +def _attach_meta_to_node_if_not_exist(model: GraphModule) -> None: + """Attach meta field to all nodes of the graph if it does not exist, + meta field is a field stores some meta information about the node, such + as dtype and shape information for output of the node, this only exists + if the program is captured by make_fx (used in quantize_pt2e flow), if + the program is captured by torch.fx symbolic tracing, this field may not exist, + so we add it here to avoid checking this all over the places + """ + for node in model.graph.nodes: + if not hasattr(node, "meta"): + node.meta = {} + + +def _swap_ff_with_fxff(model: torch.nn.Module) -> None: + r"""Swap FloatFunctional with FXFloatFunctional""" + modules_to_swap = [] + for name, module in model.named_children(): + if isinstance(module, torch.ao.nn.quantized.FloatFunctional): + modules_to_swap.append(name) + else: + _swap_ff_with_fxff(module) + + for name in modules_to_swap: + del model._modules[name] + model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() + + +def _fuse_fx( + model: GraphModule, + is_qat: bool, + fuse_custom_config: Union[FuseCustomConfig, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""Internal helper function to fuse modules in preparation for quantization + + Args: + model: GraphModule object from symbolic tracing (torch.fx.symbolic_trace) + """ + _check_is_graph_module(model) + return fuse(model, is_qat, fuse_custom_config, backend_config) # type: ignore[operator] + + +def _prepare_fx( + model: torch.nn.Module, + qconfig_mapping: Union[QConfigMapping, dict[str, Any]], + is_qat: bool, + example_inputs: tuple[Any, ...], + prepare_custom_config: Union[PrepareCustomConfig, dict[str, Any], None] = None, + _equalization_config: Optional[Union[QConfigMapping, dict[str, Any]]] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, + is_standalone_module: bool = False, +) -> GraphModule: + r"""Internal helper function for prepare_fx + Args: + `model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`: + see docs for :func:`~torch.ao.quantization.prepare_fx` + `is_standalone_module`: a boolean flag indicates whether we are + quantizing a standalone module or not, a standalone module + is a submodule of the parent module that is not inlined in the + forward graph of the parent module, + the way we quantize standalone module is described in: + :func:`~torch.ao.quantization._prepare_standalone_module_fx` + """ + if prepare_custom_config is None: + prepare_custom_config = PrepareCustomConfig() + if _equalization_config is None: + _equalization_config = QConfigMapping() + + if isinstance(prepare_custom_config, dict): + warnings.warn( + "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a PrepareCustomConfig instead.", + FutureWarning, + stacklevel=3, + ) + prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) + + # swap FloatFunctional with FXFloatFunctional + _swap_ff_with_fxff(model) + + skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes( + prepare_custom_config, is_standalone_module + ) + preserved_attr_names = prepare_custom_config.preserved_attributes + preserved_attrs = { + attr: getattr(model, attr) + for attr in preserved_attr_names + if hasattr(model, attr) + } + # symbolically trace the model + tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) # type: ignore[arg-type] + graph_module = GraphModule(model, tracer.trace(model)) + _attach_meta_to_node_if_not_exist(graph_module) + + fuse_custom_config = FuseCustomConfig().set_preserved_attributes( + prepare_custom_config.preserved_attributes + ) + graph_module = _fuse_fx(graph_module, is_qat, fuse_custom_config, backend_config) + prepared = prepare( + graph_module, + qconfig_mapping, + is_qat, + tracer.node_name_to_scope, + example_inputs=example_inputs, + prepare_custom_config=prepare_custom_config, + _equalization_config=_equalization_config, + backend_config=backend_config, + is_standalone_module=is_standalone_module, + ) # type: ignore[operator] + + attach_preserved_attrs_to_model(prepared, preserved_attrs) + return prepared + + +def _prepare_standalone_module_fx( + model: torch.nn.Module, + qconfig_mapping: Union[QConfigMapping, dict[str, Any]], + is_qat: bool, + example_inputs: tuple[Any, ...], + prepare_custom_config: Union[PrepareCustomConfig, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""[Internal use only] Prepare a standalone module, so that it can be used when quantizing the + parent module. + standalone_module means it a submodule that is not inlined in parent module, + and will be quantized separately as one unit. + + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module + + Returns: + + * model(GraphModule): prepared standalone module. It has these attributes in + model.meta: + + * `standalone_module_input_quantized_idxs(List[Int])`: a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + * `standalone_module_output_quantized_idxs(List[Int])`: a list of + indexs for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module + + """ + return _prepare_fx( + model, + qconfig_mapping, + is_qat, + example_inputs, + prepare_custom_config, + backend_config=backend_config, + is_standalone_module=True, + ) + + +def fuse_fx( + model: torch.nn.Module, + fuse_custom_config: Union[FuseCustomConfig, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode. + Fusion rules are defined in torch.ao.quantization.fx.fusion_pattern.py + + Args: + + * `model` (torch.nn.Module): a torch.nn.Module model + * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx. + See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details + Example:: + + from torch.ao.quantization import fuse_fx + + m = Model().eval() + m = fuse_fx(m) + + """ + if fuse_custom_config is None: + fuse_custom_config = FuseCustomConfig() + + if isinstance(fuse_custom_config, dict): + warnings.warn( + "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " + "in a future version. Please pass in a FuseCustomConfig instead.", + FutureWarning, + stacklevel=2, + ) + fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config) + + torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx") + preserved_attr_names = fuse_custom_config.preserved_attributes + preserved_attrs = { + attr: getattr(model, attr) + for attr in preserved_attr_names + if hasattr(model, attr) + } + + graph_module = torch.fx.symbolic_trace(model) + _attach_meta_to_node_if_not_exist(graph_module) + graph_module = _fuse_fx(graph_module, False, fuse_custom_config, backend_config) + + attach_preserved_attrs_to_model(graph_module, preserved_attrs) + return graph_module + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_fx( + model: torch.nn.Module, + qconfig_mapping: Union[QConfigMapping, dict[str, Any]], + example_inputs: tuple[Any, ...], + prepare_custom_config: Union[PrepareCustomConfig, dict[str, Any], None] = None, + _equalization_config: Optional[Union[QConfigMapping, dict[str, Any]]] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r""" Prepare a model for post training quantization + + Args: + * `model` (torch.nn.Module): torch.nn.Module model + + * `qconfig_mapping` (QConfigMapping): QConfigMapping object to configure how a model is + quantized, see :class:`~torch.ao.quantization.qconfig_mapping.QConfigMapping` + for more details + + * `example_inputs` (Tuple[Any, ...]): Example inputs for forward function of the model, + Tuple of positional args (keyword args can be passed as positional args as well) + + * `prepare_custom_config` (PrepareCustomConfig): customization configuration for quantization tool. + See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more details + + * `_equalization_config`: config for specifying how to perform equalization on the model + + * `backend_config` (BackendConfig): config that specifies how operators are quantized + in a backend, this includes how the operators are observed, + supported fusion patterns, how quantize/dequantize ops are + inserted, supported dtypes etc. See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details + + Return: + A GraphModule with observer (configured by qconfig_mapping), ready for calibration + + Example:: + + import torch + from torch.ao.quantization import get_default_qconfig_mapping + from torch.ao.quantization.quantize_fx import prepare_fx + + class Submodule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + def forward(self, x): + x = self.linear(x) + return x + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Submodule() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + x + return x + + # initialize a floating point model + float_model = M().eval() + + # define calibration function + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + # qconfig is the configuration for how we insert observers for a particular + # operator + # qconfig = get_default_qconfig("fbgemm") + # Example of customizing qconfig: + # qconfig = torch.ao.quantization.QConfig( + # activation=MinMaxObserver.with_args(dtype=torch.qint8), + # weight=MinMaxObserver.with_args(dtype=torch.qint8)) + # `activation` and `weight` are constructors of observer module + + # qconfig_mapping is a collection of quantization configurations, user can + # set the qconfig for each operator (torch op calls, functional calls, module calls) + # in the model through qconfig_mapping + # the following call will get the qconfig_mapping that works best for models + # that target "fbgemm" backend + qconfig_mapping = get_default_qconfig_mapping("fbgemm") + + # We can customize qconfig_mapping in different ways. + # e.g. set the global qconfig, which means we will use the same qconfig for + # all operators in the model, this can be overwritten by other settings + # qconfig_mapping = QConfigMapping().set_global(qconfig) + # e.g. quantize the linear submodule with a specific qconfig + # qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig) + # e.g. quantize all nn.Linear modules with a specific qconfig + # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig) + # for a more complete list, please see the docstring for :class:`torch.ao.quantization.QConfigMapping` + # argument + + # example_inputs is a tuple of inputs, that is used to infer the type of the + # outputs in the model + # currently it's not used, but please make sure model(*example_inputs) runs + example_inputs = (torch.randn(1, 3, 224, 224),) + + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + # `prepare_fx` inserts observers in the model based on qconfig_mapping and + # backend_config. If the configuration for an operator in qconfig_mapping + # is supported in the backend_config (meaning it's supported by the target + # hardware), we'll insert observer modules according to the qconfig_mapping + # otherwise the configuration in qconfig_mapping will be ignored + # + # Example: + # in qconfig_mapping, user sets linear module to be quantized with quint8 for + # activation and qint8 for weight: + # qconfig = torch.ao.quantization.QConfig( + # observer=MinMaxObserver.with_args(dtype=torch.quint8), + # weight=MinMaxObserver.with-args(dtype=torch.qint8)) + # Note: current qconfig api does not support setting output observer, but + # we may extend this to support these more fine grained control in the + # future + # + # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig) + # in backend config, linear module also supports in this configuration: + # weighted_int8_dtype_config = DTypeConfig( + # input_dtype=torch.quint8, + # output_dtype=torch.quint8, + # weight_dtype=torch.qint8, + # bias_type=torch.float) + + # linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \ + # .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + # .add_dtype_config(weighted_int8_dtype_config) \ + # ... + + # backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config) + # `prepare_fx` will check that the setting requested by suer in qconfig_mapping + # is supported by the backend_config and insert observers and fake quant modules + # in the model + prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs) + # Run calibration + calibrate(prepared_model, sample_inference_data) + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx") + return _prepare_fx( + model, + qconfig_mapping, + False, # is_qat + example_inputs, + prepare_custom_config, + _equalization_config, + backend_config, + ) + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_qat_fx( + model: torch.nn.Module, + qconfig_mapping: Union[QConfigMapping, dict[str, Any]], + example_inputs: tuple[Any, ...], + prepare_custom_config: Union[PrepareCustomConfig, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""Prepare a model for quantization aware training + + Args: + * `model` (torch.nn.Module): torch.nn.Module model + * `qconfig_mapping` (QConfigMapping): see :func:`~torch.ao.quantization.prepare_fx` + * `example_inputs` (Tuple[Any, ...]): see :func:`~torch.ao.quantization.prepare_fx` + * `prepare_custom_config` (PrepareCustomConfig): see :func:`~torch.ao.quantization.prepare_fx` + * `backend_config` (BackendConfig): see :func:`~torch.ao.quantization.prepare_fx` + + Return: + A GraphModule with fake quant modules (configured by qconfig_mapping and backend_config), ready for + quantization aware training + + Example:: + + import torch + from torch.ao.quantization import get_default_qat_qconfig_mapping + from torch.ao.quantization.quantize_fx import prepare_qat_fx + + + class Submodule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + + def forward(self, x): + x = self.linear(x) + return x + + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 5) + self.sub = Submodule() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + x + return x + + + # initialize a floating point model + float_model = M().train() + # (optional, but preferred) load the weights from pretrained model + # float_model.load_weights(...) + + + # define the training loop for quantization aware training + def train_loop(model, train_data): + model.train() + for image, target in data_loader: + ... + + + # qconfig is the configuration for how we insert observers for a particular + # operator + # qconfig = get_default_qconfig("fbgemm") + # Example of customizing qconfig: + # qconfig = torch.ao.quantization.QConfig( + # activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)), + # weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8))) + # `activation` and `weight` are constructors of observer module + + # qconfig_mapping is a collection of quantization configurations, user can + # set the qconfig for each operator (torch op calls, functional calls, module calls) + # in the model through qconfig_mapping + # the following call will get the qconfig_mapping that works best for models + # that target "fbgemm" backend + qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm") + + # We can customize qconfig_mapping in different ways, please take a look at + # the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways + # to configure this + + # example_inputs is a tuple of inputs, that is used to infer the type of the + # outputs in the model + # currently it's not used, but please make sure model(*example_inputs) runs + example_inputs = (torch.randn(1, 3, 224, 224),) + + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + # `prepare_qat_fx` inserts observers in the model based on qconfig_mapping and + # backend_config, if the configuration for an operator in qconfig_mapping + # is supported in the backend_config (meaning it's supported by the target + # hardware), we'll insert fake_quantize modules according to the qconfig_mapping + # otherwise the configuration in qconfig_mapping will be ignored + # see :func:`~torch.ao.quantization.prepare_fx` for a detailed explanation of + # how qconfig_mapping interacts with backend_config + prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs) + # Run training + train_loop(prepared_model, train_loop) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx") + return _prepare_fx( + model, + qconfig_mapping, + True, # is_qat + example_inputs, + prepare_custom_config, + backend_config=backend_config, + ) + + +def _convert_fx( + graph_module: GraphModule, + is_reference: bool, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, + is_standalone_module: bool = False, + _remove_qconfig: bool = True, + qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, + is_decomposed: bool = False, + keep_original_weights: bool = False, +) -> GraphModule: + """`is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`""" + if convert_custom_config is None: + convert_custom_config = ConvertCustomConfig() + + if isinstance(convert_custom_config, dict): + warnings.warn( + "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " + "in a future version. Please pass in a ConvertCustomConfig instead.", + FutureWarning, + stacklevel=3, + ) + convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) + + _check_is_graph_module(graph_module) + preserved_attr_names = convert_custom_config.preserved_attributes + preserved_attrs = { + attr: getattr(graph_module, attr) + for attr in preserved_attr_names + if hasattr(graph_module, attr) + } + + quantized = convert( + graph_module, + is_reference, + convert_custom_config, + is_standalone_module, + _remove_qconfig_flag=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + is_decomposed=is_decomposed, + keep_original_weights=keep_original_weights, + ) + + attach_preserved_attrs_to_model(quantized, preserved_attrs) + return quantized + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def convert_fx( + graph_module: GraphModule, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, + _remove_qconfig: bool = True, + qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, + keep_original_weights: bool = False, +) -> GraphModule: + r"""Convert a calibrated or trained model to a quantized model + + Args: + * `graph_module` (torch.fx.GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more details + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + + The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`, + with the same values or `None`. Additional keys can be specified with values set to `None`. + + For each entry whose value is set to None, we skip quantizing that entry in the model:: + + qconfig_mapping = QConfigMapping + .set_global(qconfig_from_prepare) + .set_object_type(torch.nn.functional.add, None) # skip quantizing torch.nn.functional.add + .set_object_type(torch.nn.functional.linear, qconfig_from_prepare) + .set_module_name("foo.bar", None) # skip quantizing module "foo.bar" + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend, this includes quantization + mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.), + observer placement for each operators and fused operators. + See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details + + Return: + A quantized model (torch.nn.Module) + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # convert_fx converts a calibrated/trained model to a quantized model for the + # target hardware, this includes converting the model first to a reference + # quantized model, and then lower the reference quantized model to a backend + # Currently, the supported backends are fbgemm (onednn), qnnpack (xnnpack) and + # they share the same set of quantized operators, so we are using the same + # lowering procedure + # + # backend_config defines the corresponding reference quantized module for + # the weighted modules in the model, e.g. nn.Linear + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + quantized_model = convert_fx(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx") + return _convert_fx( + graph_module, + is_reference=False, + convert_custom_config=convert_custom_config, + _remove_qconfig=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + keep_original_weights=keep_original_weights, + ) + + +def convert_to_reference_fx( + graph_module: GraphModule, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, + _remove_qconfig: bool = True, + qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""Convert a calibrated or trained model to a reference quantized model, + see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details, + reference quantized model is a standard representation of a quantized model provided + by FX Graph Mode Quantization, it can be further lowered to run on the target + hardware, like accelerators + + Args: + * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend. See + :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + Return: + A reference quantized model (GraphModule) + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + reference_quantized_model = convert_to_reference_fx(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_to_reference_fx") + return _convert_fx( + graph_module, + is_reference=True, + convert_custom_config=convert_custom_config, + _remove_qconfig=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + ) + + +def _convert_to_reference_decomposed_fx( + graph_module: GraphModule, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, + qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None, + backend_config: Union[BackendConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""Convert a calibrated or trained model to a reference quantized model, with + decomposed representation for quantized Tensor + see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details, + reference quantized model is a standard representation of a quantized model provided + by FX Graph Mode Quantization, it can be further lowered to run on the target + hardware, like accelerators + + Note: this is not public API + + Args: + * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend. See + :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + Return: + A reference quantized model (GraphModule) with operators working with decomposed quantized Tensor + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model) + + """ + torch._C._log_api_usage_once( + "quantization_api.quantize_fx._convert_to_reference_decomposed_fx" + ) + return _convert_fx( + graph_module, + is_reference=True, + convert_custom_config=convert_custom_config, + _remove_qconfig=False, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + is_decomposed=True, + ) + + +def _convert_standalone_module_fx( + graph_module: GraphModule, + is_reference: bool = False, + convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None, +) -> GraphModule: + r"""[Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx` + and convert it to a quantized model + + Returns a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config, with + input_quantized_idxs, output_quantized_idxs, please + see docs for prepare_fx for details + """ + return _convert_fx( + graph_module, + is_reference, + convert_custom_config, + is_standalone_module=True, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantize_jit.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantize_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..38d9cd6b8b765e7a003be1745f4770948cc3c227 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantize_jit.py @@ -0,0 +1,421 @@ +# mypy: allow-untyped-defs + +import torch +from torch.ao.quantization.qconfig import QConfig +from torch.ao.quantization.quant_type import QuantType +from torch.jit._recursive import wrap_cpp_module + + +__all__ = [ + "script_qconfig", + "script_qconfig_dict", + "fuse_conv_bn_jit", + "prepare_jit", + "prepare_dynamic_jit", + "convert_jit", + "convert_dynamic_jit", + "quantize_jit", + "quantize_dynamic_jit", +] + + +def _check_is_script_module(model): + if not isinstance(model, torch.jit.ScriptModule): + raise ValueError("input must be a script module, got: " + str(type(model))) + + +def _check_forward_method(model): + if not model._c._has_method("forward"): + raise ValueError("input script module does not have forward method") + + +def script_qconfig(qconfig): + r"""Instantiate the activation and weight observer modules and script + them, these observer module instances will be deepcopied during + prepare_jit step. + """ + return QConfig( + activation=torch.jit.script(qconfig.activation())._c, + weight=torch.jit.script(qconfig.weight())._c, + ) + + +def script_qconfig_dict(qconfig_dict): + r"""Helper function used by `prepare_jit`. + Apply `script_qconfig` for all entries in `qconfig_dict` that is + not None. + """ + return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()} + + +def fuse_conv_bn_jit(model, inplace=False): + r"""Fuse conv - bn module + Works for eval model only. + + Args: + model: TorchScript model from scripting or tracing + """ + torch._C._log_api_usage_once("quantization_api.quantize_jit.fuse_conv_bn_jit") + model_c = model._c + model_c = torch._C._jit_pass_fold_convbn(model_c) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + + +def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC): + _check_is_script_module(model) + _check_forward_method(model) + if not all(isinstance(x, str) for x in qconfig_dict.keys()): + raise ValueError("qconfig_dict should only contain names(str) as keys.") + scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) + model = fuse_conv_bn_jit(model, inplace) + model_c = torch._C._jit_pass_insert_observers( + model._c, "forward", scripted_qconfig_dict, inplace, quant_type + ) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + + +def _prepare_ondevice_jit( + model, + qconfig_dict, + method_name="forward", + inplace=False, + quant_type=QuantType.STATIC, +): + _check_is_script_module(model) + if not all(isinstance(x, str) for x in qconfig_dict.keys()): + raise ValueError("qconfig_dict should only contain names(str) as keys.") + scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) + method_graph = model._c._get_method(method_name).graph + torch._C._jit_pass_inline(method_graph) + model = fuse_conv_bn_jit(model, inplace) + model_c = torch._C._jit_pass_insert_observer_method_for_ondevice_ptq( + model._c, method_name, scripted_qconfig_dict, inplace, quant_type + ) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + + +def prepare_jit(model, qconfig_dict, inplace=False): + torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_jit") + return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.STATIC) + + +def prepare_dynamic_jit(model, qconfig_dict, inplace=False): + torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_dynamic_jit") + return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC) + + +def _prepare_ondevice_dynamic_jit( + model, qconfig_dict, method_name="forward", inplace=False +): + return _prepare_ondevice_jit( + model, qconfig_dict, method_name, inplace, quant_type=QuantType.DYNAMIC + ) + + +def _convert_jit( + model, inplace=False, debug=False, quant_type=QuantType.STATIC, preserved_attrs=None +): + _check_is_script_module(model) + model.eval() + model_c = model._c + model_c = torch._C._jit_pass_insert_quant_dequant( + model_c, "forward", inplace, debug, quant_type + ) + if not debug: + is_xpu = all(p.device.type == "xpu" for p in model.parameters()) + if not is_xpu: + # Moving model parameters to CPU since quantized operators + # are only supported on CPU and XPU right now + model.cpu() + if preserved_attrs is None: + preserved_attrs = [] + model_c = torch._C._jit_pass_quant_finalize( + model_c, quant_type, preserved_attrs + ) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + torch._C._jit_pass_constant_propagation(model.graph) + torch._C._jit_pass_dce(model.graph) + return model + + +def _convert_ondevice_jit( + model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC +): + _check_is_script_module(model) + assert quant_type == QuantType.DYNAMIC, ( + "This API, while should work for static quant, is only tested for dynamic quant." + ) + assert not method_name.startswith("observe_"), ( + "Pass in valid method to be quantized, e.g. forward" + ) + observe_method_name = "observe_" + method_name + quantize_method_name = "quantize_" + method_name + model_c = model._c + model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq( + model._c, observe_method_name, inplace, debug, QuantType.DYNAMIC + ) + model_c = torch._C._jit_pass_quant_finalize_for_ondevice_ptq( + model_c, QuantType.DYNAMIC, quantize_method_name + ) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + + +def convert_jit(model, inplace=False, debug=False, preserved_attrs=None): + torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit") + return _convert_jit( + model, + inplace, + debug, + quant_type=QuantType.STATIC, + preserved_attrs=preserved_attrs, + ) + + +def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None): + torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit") + return _convert_jit( + model, + inplace, + debug, + quant_type=QuantType.DYNAMIC, + preserved_attrs=preserved_attrs, + ) + + +def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False): + return _convert_ondevice_jit( + model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC + ) + + +def _quantize_ondevice_dynamic_jit_impl( + model, qconfig_dict, method_name, inplace=False +): + model = _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name, inplace) + model = _convert_ondevice_dynamic_jit(model, method_name, inplace) + return model + + +def _quantize_jit( + model, + qconfig_dict, + run_fn=None, + run_args=None, + inplace=False, + debug=False, + quant_type=QuantType.STATIC, +): + # Always do inplace convert because the Tensor is already + # copied in prepare_jit when inplace is False + if quant_type == QuantType.DYNAMIC: + model = prepare_dynamic_jit(model, qconfig_dict, inplace) + model = convert_dynamic_jit(model, True, debug) + else: + assert run_fn, ( + "Must provide calibration function for post training static quantization" + ) + assert run_args, ( + "Must provide calibration dataset for post training static quantization" + ) + model = prepare_jit(model, qconfig_dict, inplace) + run_fn(model, *run_args) + model = convert_jit(model, True, debug) + + torch._C._jit_pass_constant_propagation(model.graph) + torch._C._jit_pass_dce(model.graph) + return model + + +def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False): + r"""Quantize the input float TorchScript model with + post training static quantization. + + First it will prepare the model for calibration, then it calls + `run_fn` which will run the calibration step, after that we will + convert the model to a quantized model. + + Args: + `model`: input float TorchScript model + `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and + qconfig for that module as value, empty key means the qconfig will be applied + to whole model unless it's overwritten by more specific configurations, the + qconfig for each module is either found in the dictionary or fallback to + the qconfig of parent module. + + Right now qconfig_dict is the only way to configure how the model is quantized, + and it is done in the granularity of module, that is, we only support one type + of qconfig for each torch.nn.Module, and the qconfig for sub module will + override the qconfig for parent module, empty string means global configuration. + `run_fn`: a calibration function for calibrating the prepared model + `run_args`: positional arguments for `run_fn` + `inplace`: carry out model transformations in-place, the original module is + mutated + `debug`: flag for producing a debug friendly model (preserve weight attribute) + + Return: + Quantized TorchSciprt model. + + Example: + ```python + import torch + from torch.ao.quantization import get_default_qconfig + from torch.ao.quantization import quantize_jit + + ts_model = torch.jit.script( + float_model.eval() + ) # or torch.jit.trace(float_model, input) + qconfig = get_default_qconfig("fbgemm") + + + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + + quantized_model = quantize_jit( + ts_model, {"": qconfig}, calibrate, [data_loader_test] + ) + ``` + """ + torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_jit") + return _quantize_jit( + model, + qconfig_dict, + run_fn, + run_args, + inplace, + debug, + quant_type=QuantType.STATIC, + ) + + +def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False): + r"""Quantize the input float TorchScript model with + post training dynamic quantization. + Currently only qint8 quantization of torch.nn.Linear is supported. + + Args: + `model`: input float TorchScript model + `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and + qconfig for that module as value, please see detailed + descriptions in :func:`~torch.ao.quantization.quantize_jit` + `inplace`: carry out model transformations in-place, the original module is + mutated + `debug`: flag for producing a debug friendly model (preserve weight attribute) + + Return: + Quantized TorchSciprt model. + + Example: + ```python + import torch + from torch.ao.quantization import per_channel_dynamic_qconfig + from torch.ao.quantization import quantize_dynamic_jit + + ts_model = torch.jit.script( + float_model.eval() + ) # or torch.jit.trace(float_model, input) + qconfig = get_default_qconfig("fbgemm") + + + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + + quantized_model = quantize_dynamic_jit( + ts_model, {"": qconfig}, calibrate, [data_loader_test] + ) + ``` + """ + torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit") + return _quantize_jit( + model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC + ) + + +def _quantize_ondevice_dynamic_jit( + model, qconfig_dict, method_name="forward", inplace=False +): + r"""Prepares the input float TorchScript model with + *on-device* post training dynamic quantization. + Currently only qint8 quantization of torch.nn.Linear is supported. + + Args: + `model`: input float TorchScript model + `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and + qconfig for that module as value, please see detailed + `method_name`: Name of the method within the model, to be prepared for quantization + descriptions in :func:`~torch.ao.quantization.quantize_jit` + `inplace`: carry out model transformations in-place, the original module is + mutated + + Return: + TorchScript model that is ready for on device quantization. + This means that the returned + model has: + - Method is inlined. + - Model has observer modules inserted in the model. + - Model has packed params inserted in the model. However they are empty as in they dont + contain valid quantized weights. + - observe_ is added that observe the values to be quantized. + - reset_observers_ to reset observers. + - quantize_ is added to the model. + - This method extract scale, zero points. + - Quantizes observed weights. + - Creates packed params from it and update the attribute of the model with the new values + for the packed params. + - Reset the original fp32 weights with empty tensor using SetAttr. + - quantized_ is added to the model. + - This method uses quantized weights and quantized linear ops instead of fp32 op. + - This method should be used for inference post PTQ. + - Note that all method's signatures should be the same as method_name. + + Later on device: + - Run reset_observers_ + - Run observe_ + - Run quantize_ + - Now model can be saved and loaded later. + - Run model with quantized_ + + Example: + ```python + import torch + from torch.ao.quantization import per_channel_dynamic_qconfig + from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit + + ts_model = torch.jit.script( + float_model.eval() + ) # or torch.jit.trace(float_model, input) + qconfig = get_default_qconfig("fbgemm") + quant_ready_model = _quantize_ondevice_dynamic_jit( + ts_model, {"": qconfig}, "forward", True + ) + ``` + """ + return _quantize_ondevice_dynamic_jit_impl( + model, qconfig_dict, method_name, inplace=inplace + ) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantize_pt2e.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantize_pt2e.py new file mode 100644 index 0000000000000000000000000000000000000000..3188eba9e96c5b2883c5875025563e6700472d71 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/quantize_pt2e.py @@ -0,0 +1,262 @@ +import typing_extensions + +import torch +from torch._export.passes.constant_folding import constant_fold +from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass +from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ +from torch.ao.quantization.quantizer import ( # noqa: F401 + DerivedQuantizationSpec, + FixedQParamsQuantizationSpec, + QuantizationAnnotation, + QuantizationSpec, + QuantizationSpecBase, + Quantizer, + SharedQuantizationSpec, +) +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_manager import PassManager + +from .pt2e.prepare import prepare +from .pt2e.qat_utils import _fold_conv_bn_qat, _fuse_conv_bn_qat +from .pt2e.representation import reference_representation_rewrite +from .pt2e.utils import _disallow_eval_train, _fuse_conv_bn_, _get_node_name_to_scope +from .quantize_fx import _convert_to_reference_decomposed_fx +from .utils import DEPRECATION_WARNING + + +__all__ = [ + "prepare_pt2e", + "prepare_qat_pt2e", + "convert_pt2e", +] + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_pt2e( + model: GraphModule, + quantizer: Quantizer, +) -> GraphModule: + """Prepare a model for post training quantization + + Args: + * `model` (torch.fx.GraphModule): a model captured by `torch.export.export_for_training` API. + * `quantizer`: A backend specific quantizer that conveys how user want the + model to be quantized. Tutorial for how to write a quantizer can be found here: + https://pytorch.org/tutorials/prototype/pt2e_quantizer.html + + Return: + A GraphModule with observer (based on quantizer annotation), ready for calibration + + Example:: + + import torch + from torch.ao.quantization.quantize_pt2e import prepare_pt2e + from torch.ao.quantization.quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + # initialize a floating point model + float_model = M().eval() + + # define calibration function + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + # Step 1. program capture + # NOTE: this API will be updated to torch.export API in the future, but the captured + # result shoud mostly stay the same + m = torch.export.export_for_training(m, *example_inputs).module() + # we get a model with aten ops + + # Step 2. quantization + # backend developer will write their own Quantizer and expose methods to allow + # users to express how they + # want the model to be quantized + quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) + m = prepare_pt2e(m, quantizer) + + # run calibration + # calibrate(m, sample_inference_data) + """ + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e") + original_graph_meta = model.meta + node_name_to_scope = _get_node_name_to_scope(model) + # TODO: check qconfig_mapping to make sure conv and bn are both configured + # to be quantized before fusion + # TODO: (maybe) rewrite this with subgraph_rewriter + _fuse_conv_bn_(model) + model = quantizer.transform_for_annotation(model) + quantizer.annotate(model) + quantizer.validate(model) + model = prepare( + model, + node_name_to_scope, + is_qat=False, + obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback, + ) + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def prepare_qat_pt2e( + model: GraphModule, + quantizer: Quantizer, +) -> GraphModule: + """Prepare a model for quantization aware training + + Args: + * `model` (torch.fx.GraphModule): see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e` + * `quantizer`: see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e` + + Return: + A GraphModule with fake quant modules (based on quantizer annotation), ready for + quantization aware training + + Example:: + import torch + from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e + from torch.ao.quantization.quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + # initialize a floating point model + float_model = M().eval() + + # define the training loop for quantization aware training + def train_loop(model, train_data): + model.train() + for image, target in data_loader: + ... + + # Step 1. program capture + # NOTE: this API will be updated to torch.export API in the future, but the captured + # result shoud mostly stay the same + m = torch.export.export_for_training(m, *example_inputs).module() + # we get a model with aten ops + + # Step 2. quantization + # backend developer will write their own Quantizer and expose methods to allow + # users to express how they + # want the model to be quantized + quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) + m = prepare_qat_pt2e(m, quantizer) + + # run quantization aware training + train_loop(prepared_model, train_loop) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e") + original_graph_meta = model.meta + node_name_to_scope = _get_node_name_to_scope(model) + model = quantizer.transform_for_annotation(model) + quantizer.annotate(model) + quantizer.validate(model) + # Perform fusion after annotate to avoid quantizing ops in the new + # subgraph that don't need to be quantized + # TODO: only fuse if conv and bn are both configured to be quantized + _fuse_conv_bn_qat(model) + model = prepare( + model, + node_name_to_scope, + is_qat=True, + obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback, + ) + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model + + +_QUANT_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.pt2e_quant.quantize_affine, +] + + +def _quant_node_constraint(n: Node) -> bool: + """If there is any pure ops between get_attr and quantize op they will be const propagated + e.g. get_attr(weight) -> transpose -> quantize -> dequantize* + (Note: dequantize op is not going to be constant propagated) + + This filter is added because we don't want to constant fold the things that are not + related to quantization + """ + return n.op == "call_function" and n.target in _QUANT_OPS + + +@typing_extensions.deprecated(DEPRECATION_WARNING) +def convert_pt2e( + model: GraphModule, + use_reference_representation: bool = False, + fold_quantize: bool = True, +) -> GraphModule: + """Convert a calibrated/trained model to a quantized model + + Args: + * `model` (torch.fx.GraphModule): calibrated/trained model + * `use_reference_representation` (bool): boolean flag to indicate whether to produce referece representation or not + * `fold_quantize` (bool): boolean flag for whether fold the quantize op or not + + Returns: + quantized model, either in q/dq representation or reference representation + + Example:: + + # prepared_model: the model produced by `prepare_pt2e`/`prepare_qat_pt2e` and calibration/training + # `convert_pt2e` produces a quantized model that represents quantized computation with + # quantize dequantize ops and fp32 ops by default. + # Please refer to + # https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html#convert-the-calibrated-model-to-a-quantized-model + # for detailed explanation of output quantized model + quantized_model = convert_pt2e(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e") + if not isinstance(use_reference_representation, bool): + raise ValueError( + "Unexpected argument type for `use_reference_representation`, " + f"please make sure you intend to pass argument {use_reference_representation} to convert_pt2e" + ) + original_graph_meta = model.meta + model = _convert_to_reference_decomposed_fx(model) + model = _fold_conv_bn_qat(model) + + pm = PassManager([DuplicateDQPass()]) + model = pm(model).graph_module + + pm = PassManager([PortNodeMetaForQDQ()]) + model = pm(model).graph_module + + if fold_quantize: + constant_fold(model, _quant_node_constraint) + + if use_reference_representation: + model = reference_representation_rewrite(model) + + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/stubs.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/stubs.py new file mode 100644 index 0000000000000000000000000000000000000000..ebfffcb756f76500451611daffaed9655bf95bf1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/stubs.py @@ -0,0 +1,74 @@ +from typing import Any, Optional + +import torch +from torch import nn +from torch.ao.quantization import QConfig + + +__all__ = ["QuantStub", "DeQuantStub", "QuantWrapper"] + + +class QuantStub(nn.Module): + r"""Quantize stub module, before calibration, this is same as an observer, + it will be swapped as `nnq.Quantize` in `convert`. + + Args: + qconfig: quantization configuration for the tensor, + if qconfig is not provided, we will get qconfig from parent modules + """ + + def __init__(self, qconfig: Optional[QConfig] = None): + super().__init__() + if qconfig: + self.qconfig = qconfig + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class DeQuantStub(nn.Module): + r"""Dequantize stub module, before calibration, this is same as identity, + this will be swapped as `nnq.DeQuantize` in `convert`. + + Args: + qconfig: quantization configuration for the tensor, + if qconfig is not provided, we will get qconfig from parent modules + """ + + def __init__(self, qconfig: Optional[Any] = None): + super().__init__() + if qconfig: + self.qconfig = qconfig + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class QuantWrapper(nn.Module): + r"""A wrapper class that wraps the input module, adds QuantStub and + DeQuantStub and surround the call to module with call to quant and dequant + modules. + + This is used by the `quantization` utility functions to add the quant and + dequant modules, before `convert` function `QuantStub` will just be observer, + it observes the input tensor, after `convert`, `QuantStub` + will be swapped to `nnq.Quantize` which does actual quantization. Similarly + for `DeQuantStub`. + """ + + quant: QuantStub + dequant: DeQuantStub + module: nn.Module + + def __init__(self, module: nn.Module): + super().__init__() + qconfig = getattr(module, "qconfig", None) + self.add_module("quant", QuantStub(qconfig)) + self.add_module("dequant", DeQuantStub(qconfig)) + self.add_module("module", module) + self.train(module.training) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + X = self.quant(X) + X = self.module(X) + return self.dequant(X) diff --git a/.venv/lib/python3.12/site-packages/torch/ao/quantization/utils.py b/.venv/lib/python3.12/site-packages/torch/ao/quantization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac3112ec072f1f40d3e60f1c50cf5ad02c33da7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/ao/quantization/utils.py @@ -0,0 +1,838 @@ +# mypy: allow-untyped-defs +""" +Utils shared by different modes of quantization (eager/graph) +""" + +import functools +import warnings +from collections import OrderedDict +from inspect import getfullargspec, signature +from typing import Any, Callable, Optional, Union + +import torch +from torch.ao.quantization.quant_type import QuantType +from torch.fx import Node +from torch.nn.utils.parametrize import is_parametrized + + +NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] +NodePattern.__module__ = "torch.ao.quantization.utils" + +# This is the Quantizer class instance from torch/quantization/fx/quantize.py. +# Define separately to prevent circular imports. +# TODO(future PR): improve this. +# make this public once fixed (can't be public as is because setting the module directly +# doesn't work) +QuantizerCls = Any + +# Type for fusion patterns, it can be more complicated than the following actually, +# see pattern.md for docs +# TODO: not sure if typing supports recursive data types +Pattern = Union[ + Callable, tuple[Callable, Callable], tuple[Callable, tuple[Callable, Callable]], Any +] +Pattern.__module__ = "torch.ao.quantization.utils" + + +# TODO: maybe rename this to MatchInputNode +class MatchAllNode: + """A node pattern that matches all nodes, used in defining + fusion patterns in FX Graph Mode Quantization + """ + + +module_type_list = { + torch.nn.ReLU, + torch.nn.ReLU6, + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.Identity, + torch.nn.Hardsigmoid, + torch.nn.Sigmoid, + torch.nn.Tanh, +} +func_list = { + torch.nn.functional.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.instance_norm, + torch.nn.functional.layer_norm, + torch.nn.functional.leaky_relu, + torch.nn.functional.silu, + torch.nn.functional.mish, + torch.nn.functional.dropout, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.relu, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.hardsigmoid, + torch.nn.functional.sigmoid, + torch.transpose, + torch.repeat_interleave, + torch.sigmoid, + torch.squeeze, + torch.stack, + torch.sum, + torch.tanh, + torch.unsqueeze, + torch.cat, +} +method_list = { + torch.mean, + "relu", + "relu_", + "contiguous", + "detach", + "detach_", + "hardsigmoid", + "hardsigmoid_", + "permute", + "repeat", + "repeat_interleave", + "reshape", + "resize_", + "shape", + "sigmoid", + "sigmoid_", + "size", + "squeeze", + "squeeze_", + "tanh", + "tanh_", + "transpose", + "unsqueeze", + "unsqueeze_", + "view", +} + + +# TODO: not used now, remove +def check_node(node, modules): + # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py + is_call_function = node.op == "call_function" and node.target in func_list + is_call_method = node.op == "call_method" and node.target in method_list + is_call_module = ( + node.op == "call_module" and type(modules[str(node.target)]) in module_type_list + ) + return is_call_function, is_call_method, is_call_module + + +def get_combined_dict(default_dict, additional_dict): + """ + Combines two dictionaries. + + This function takes two dictionaries as input and returns a new dictionary + that contains all the key-value pairs from both input dictionaries. + If there are any duplicate keys in the `additional_dict`, the values + from the `additional_dict` will overwrite those in the `default_dict`. + Args: + default_dict (dict): The main dictionary that will be used as the base + additional_dict (dict): The dictionary used to update `default_dict` + + Returns: + dict: The resulting dictionary + Example: + >>> x = dict(a=1, b=1) + >>> y = dict(b=2, c=3) + >>> get_combined_dict(x, y) + {'a': 1, 'b': 2, 'c': 3} + """ + d = default_dict.copy() + d.update(additional_dict) + return d + + +def is_per_tensor(qscheme): + return qscheme == torch.per_tensor_affine or qscheme == torch.per_tensor_symmetric + + +def is_per_channel(qscheme): + return qscheme in [ + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + torch.per_channel_symmetric, + ] + + +def getattr_from_fqn(obj: Any, fqn: str) -> Any: + """ + Given an obj and a fqn such as "foo.bar.baz", returns gm.foo.bar.baz. + """ + return functools.reduce(getattr, fqn.split("."), obj) + + +def to_underlying_dtype(qdtype): + DTYPE_MAPPING = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + torch.quint4x2: torch.uint8, + torch.quint2x4: torch.uint8, + torch.uint8: torch.uint8, + torch.int8: torch.int8, + torch.uint16: torch.uint16, + torch.int16: torch.int16, + torch.int32: torch.int32, + torch.float8_e5m2: torch.float8_e5m2, + torch.float8_e4m3fn: torch.float8_e4m3fn, + } + assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype) + return DTYPE_MAPPING[qdtype] + + +def get_qparam_dict(observer_or_fake_quant): + from torch.ao.quantization.observer import PlaceholderObserver + + qscheme = getattr(observer_or_fake_quant, "qscheme", None) + dtype = observer_or_fake_quant.dtype + qparams = {"qscheme": qscheme, "dtype": dtype} + + if not qscheme or isinstance(observer_or_fake_quant, PlaceholderObserver): + return {"qscheme": None, "dtype": dtype} + + if is_per_tensor(qscheme): + qscheme = torch.per_tensor_affine + elif is_per_channel(qscheme): + # change symmetric to affine since we do not have symmetric + # quantized Tensor + if qscheme == torch.per_channel_symmetric: + qscheme = torch.per_channel_affine + qparams["axis"] = observer_or_fake_quant.ch_axis + else: + raise RuntimeError(f"Unrecognized qscheme: {qscheme}") + # update qscheme, since we don't have symmetric quant qscheme + # in quantized Tensor + qparams["qscheme"] = qscheme + + scale, zero_point = observer_or_fake_quant.calculate_qparams() + qparams["scale"] = scale + qparams["zero_point"] = zero_point + + if hasattr(observer_or_fake_quant, "quant_min"): + qparams["quant_min"] = observer_or_fake_quant.quant_min + if hasattr(observer_or_fake_quant, "quant_max"): + qparams["quant_max"] = observer_or_fake_quant.quant_max + + return qparams + + +def get_swapped_custom_module_class( + custom_module, custom_module_class_mapping, qconfig +): + """Get the observed/quantized custom module class that we need + to swap `custom_module` to + Input: + custom_module: input, can be an instance of either a float or observed custom module + custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping + qconfig: qconfig configured for the custom module + + Output: + corresponding observed/quantized custom module class for input custom module instance + """ + quant_type = get_quant_type(qconfig) + class_mapping = custom_module_class_mapping.get(quant_type, {}) + assert type(custom_module) in class_mapping, ( + "did not find corresponding observed " + f"module class for {type(custom_module)} in mapping: {class_mapping}" + ) + return class_mapping[type(custom_module)] + + +def activation_dtype(qconfig): + assert qconfig is not None + activation = qconfig.activation() + return activation.dtype + + +def weight_dtype(qconfig): + assert qconfig is not None + weight = qconfig.weight() + return weight.dtype + + +def activation_is_statically_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + quantized or not, this includes quantizing to quint8, qint8 and qint32 and float16 + """ + return activation_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.float16, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] and (not activation_is_dynamically_quantized(qconfig)) + + +def activation_is_dynamically_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + dynamically quantized or not, this includes dynamically quantizing to + quint8, qint8 and float16 + """ + _activation_dtype, _, activation_is_dynamic = get_qconfig_dtypes(qconfig) + return activation_is_dynamic + + +def activation_is_int8_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + quantized to int8 or not, this includes quantizing to quint8, qint8 + """ + return activation_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.uint8, + torch.int8, + ] + + +def activation_is_int32_quantized(qconfig): + """Given a qconfig, decide if the activation needs to be + quantized to int32 or not + """ + return activation_dtype(qconfig) in [torch.qint32, torch.int32] + + +def weight_is_quantized(qconfig): + """Given a qconfig, decide if the weight needs to be + quantized or not + """ + return weight_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.float16, + torch.quint4x2, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] + + +def weight_is_statically_quantized(qconfig): + """Given a qconfig, decide if the weight needs to be statically + quantized or not + """ + return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.uint8, torch.int8] + + +def op_is_int8_dynamically_quantized(qconfig) -> bool: + """Given a qconfig, returns True if this op is using int8 dynamic + quantization + """ + activation_dtype, weight_dtype, activation_is_dynamic = get_qconfig_dtypes(qconfig) + return ( + activation_dtype in [torch.quint8, torch.uint8] + and + # for now, the lines below assume fbgemm or qnnpack + weight_dtype in [torch.qint8, torch.int8] + and activation_is_dynamic + ) + + +def get_qconfig_dtypes(qconfig): + r"""returns the qconfig tuple for qconfig: + (activation_dtype, weight_dtype, activation_is_dynamic) + """ + assert qconfig is not None + activation = qconfig.activation() + weight = qconfig.weight() + act_is_dynamic = getattr(activation, "is_dynamic", False) + return (activation.dtype, weight.dtype, act_is_dynamic) + + +def get_quant_type(qconfig): + assert qconfig is not None + activation = qconfig.activation() + weight = qconfig.weight() + static_dtypes = [ + torch.quint8, + torch.qint8, + torch.quint4x2, + torch.qint32, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.float8_e5m2, + torch.float8_e4m3fn, + ] + if weight.dtype in static_dtypes: + if hasattr(activation, "is_dynamic") and activation.is_dynamic: + return QuantType.DYNAMIC + elif activation.dtype in static_dtypes: + return QuantType.STATIC + else: + return QuantType.WEIGHT_ONLY + + if weight.dtype == torch.float16: + if hasattr(activation, "is_dynamic") and activation.is_dynamic: + return QuantType.DYNAMIC + elif activation.dtype == torch.float16: + return QuantType.STATIC + + raise Exception( # noqa: TRY002 + f"Unrecognized dtype combination in get_quant_type: activation({activation.dtype})," + f"weight({weight.dtype})" + ) + + +def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool: + """Checks if the given minimum and maximum values are valid, meaning that + they exist and the min value is less than the max value. + """ + if min_val.numel() == 0 or max_val.numel() == 0: + warnings.warn( + "must run observer before calling calculate_qparams. " + + "Returning default values." + ) + return False + + if min_val.dim() == 0 or max_val.dim() == 0: + if min_val == float("inf") and max_val == float("-inf"): + warnings.warn( + "must run observer before calling calculate_qparams. " + + "Returning default values." + ) + + return False + + assert min_val <= max_val, f"min {min_val} should be less than max {max_val}" + else: + assert torch.all(min_val <= max_val), ( + f"min {min_val} should be less than max {max_val}" + ) + + return True + + +def calculate_qmin_qmax( + quant_min: int, + quant_max: int, + has_customized_qrange: bool, + dtype: torch.dtype, + reduce_range: bool, +) -> tuple[int, int]: + r"""Calculates actual qmin and qmax based on the quantization range, + observer datatype and if range is reduced. + """ + # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted. + if has_customized_qrange: + # This initialization here is to be resolve TorchScript compilation issues and allow + # using of refinement to decouple initial_qmin and initial_qmax from quantization range. + # The actual values of initial_qmin and initial_qmax will be reset below. + if dtype in [torch.qint32, torch.int32]: + initial_quant_min, initial_quant_max = 0, 2**32 - 1 + else: + initial_quant_min, initial_quant_max = 0, 255 + # The following assignment of self.qmin and self.qmax to the local variables and the if check refine the + # attribute from Optional valid integers for use, based on TorchScript's requirements. + custom_quant_min, custom_quant_max = quant_min, quant_max + if custom_quant_min is not None and custom_quant_max is not None: + initial_quant_min, initial_quant_max = ( + custom_quant_min, + custom_quant_max, + ) + + qrange_len = initial_quant_max - initial_quant_min + 1 + if dtype in [torch.qint8, torch.int8]: + assert 0 < qrange_len <= 256, ( + "quantization range should be positive and not exceed the maximum bit range (=256)." + ) + elif dtype in [torch.qint32, torch.int32]: + assert 0 < qrange_len <= 2**32, ( + "quantization range should be positive and not exceed the maximum bit range (=4294967296)." + ) + if reduce_range: + quant_min, quant_max = quant_min // 2, quant_max // 2 + else: + # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used. + if dtype in [torch.qint8, torch.int8]: + if reduce_range: + quant_min, quant_max = -64, 63 + else: + quant_min, quant_max = -128, 127 + elif dtype in [torch.quint8, torch.uint8]: + if reduce_range: + quant_min, quant_max = 0, 127 + else: + quant_min, quant_max = 0, 255 + elif dtype in [torch.qint32, torch.int32]: + quant_min, quant_max = -1 * (2**31), (2**31) - 1 + elif dtype in [torch.uint16]: + quant_min, quant_max = 0, 2**16 - 1 + elif dtype in [torch.int16]: + quant_min, quant_max = -(2**15), 2**15 - 1 + else: + quant_min, quant_max = 0, 15 + return quant_min, quant_max + + +def _parent_name(target): + """ + Turn 'foo.bar' into ['foo', 'bar'] + """ + r = target.rsplit(".", 1) + if len(r) == 1: + return "", r[0] + else: + return r[0], r[1] + + +def has_no_children_ignoring_parametrizations(module): + """ + Checks if module._modules is empty or + if module is a parametrization, checks that module._modules only has + the 'parametrizations' module + """ + if len(module._modules) == 0: + return True + elif is_parametrized(module): + return len(module._modules) == 1 and "parametrizations" in module._modules + else: + return False + + +def _get_path_of_module( + root: torch.nn.Module, submodule: torch.nn.Module +) -> Optional[str]: + """Get the path (fully qualified name) of a submodule + + Example:: + + >> class M(torch.nn.Module): + def __init__(self) -> None: + self.linear = torch.nn.Linear(5, 5) + def forward(self, x): + return self.linear(x) + + >> m = M() + >> l = m.linear + >> _get_path_of_module(m, l) + "linear" + """ + for n, p in root.named_modules(): + if submodule is p: + return n + return None + + +def _get_signature_locals(f: Callable, loc: dict[str, Any]) -> dict[str, Any]: + """Get local keyword arguments + + Example:: + + >> def f(self, a, b=9): + pass + >> loc = {"a": 6, "c": 7} + >> _get_signature_locals(f, loc) + {"a": 6} + """ + return {k: v for k, v in loc.items() if k in signature(f).parameters} + + +def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]": + """Get all default keyword arguments from function signature + + Example:: + + >> def f(self, a, b=9): + pass + >> _get_default_kwargs(f) + {"b": 9} + """ + kwargs = {} + for name, param in signature(f).parameters.items(): + if param.default is not param.empty: + kwargs[name] = param.default + elif param.kind is param.VAR_POSITIONAL: + kwargs[name] = () + elif param.kind is param.VAR_KEYWORD: + kwargs[name] = {} + return OrderedDict(kwargs) + + +def _normalize_kwargs(func: Callable, loc: dict[str, Any]) -> "OrderedDict[str, Any]": + """Given a function and local function arguments, normalize the keyword + arguments by filling in default arguments from function signature + + Example:: + + >> def f(self, key1=3, key2=3): + pass + >> loc = {"key2": 6} + >> _normalize_kwargs(f, loc) + {"key1": 3, "key2": 6} + """ + default_kwargs = _get_default_kwargs(func) + local_kwargs = _get_signature_locals(func, loc) + normalized_kwargs = default_kwargs.copy() + for attr, val in local_kwargs.items(): + if attr in normalized_kwargs: + # override the default keyword arguments + normalized_kwargs[attr] = val + return normalized_kwargs + + +def validate_qmin_qmax(quant_min: int, quant_max: int) -> None: + r"""Validates that the user-specified quantization range is properly initialized + and within the given bound supported by the observer dtype. + + To accommodate lower-bit quantization with respect to the existing torch.qint8 and + torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing + in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax + values are used to calculate static estimates of the scale and zero point for aggressive lower-bit + fake quantization. These estimates are compared against parameters learned through backpropagation. + The related literatures for scale and zero point via backpropagation are as follows: + + Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS + Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf + """ + # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted + # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer. + assert quant_min <= 0 <= quant_max, ( + "Used-specified quantization range must include 0." + ) + assert quant_min < quant_max, ( + "qmin must be strictly less than qmax for user-specified quantization range." + ) + + +# Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be torchscriptable however and qscheme +# as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer +# to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code seems unlikey to change +# (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. TODO(jakeszwe, jerryzh168) +def determine_qparams( + min_val: torch.Tensor, + max_val: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + eps: torch.Tensor, + has_customized_qrange: bool, + qscheme: torch.qscheme = torch.per_tensor_affine, +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Calculates the quantization parameters, given min and max + value tensors. Works for both per tensor and per channel cases + + Args: + min_val: Minimum values per channel + max_val: Maximum values per channel + + Returns: + scales: Scales tensor of shape (#channels,) + zero_points: Zero points tensor of shape (#channels,) + """ + if not check_min_max_valid(min_val, max_val): + return torch.tensor([1.0], device=min_val.device.type), torch.tensor( + [0], device=min_val.device.type + ) + + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + + device = min_val_neg.device + scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + eps = eps.to(device) + + if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric: + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = torch.max(scale, eps) + if dtype in [torch.uint8, torch.quint8]: + if has_customized_qrange: + # When customized quantization range is used, down-rounded midpoint of the range is chosen. + zero_point = zero_point.new_full( + zero_point.size(), (quant_min + quant_max) // 2 + ) + else: + zero_point = zero_point.new_full(zero_point.size(), 128) + elif qscheme == torch.per_channel_affine_float_qparams: + scale = (max_val - min_val) / float(quant_max - quant_min) + scale = torch.where(scale > eps, scale, torch.ones_like(scale)) + # We use the quantize function + # xq = Round(Xf * inv_scale + zero_point), + # setting zero_point to (-1 * min *inv_scale) we get + # Xq = Round((Xf - min) * inv_scale) + zero_point = -1 * min_val / scale + else: + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.max(scale, eps) + zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + + # For scalar values, cast them to Tensors of size 1 to keep the shape + # consistent with default values in FakeQuantize. + if len(scale.shape) == 0: + # TODO: switch to scale.item() after adding JIT support + scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) + if len(zero_point.shape) == 0: + # TODO: switch to zero_point.item() after adding JIT support + zero_point = torch.tensor( + [int(zero_point)], dtype=zero_point.dtype, device=device + ) + if qscheme == torch.per_channel_affine_float_qparams: + zero_point = torch.tensor( + [float(zero_point)], dtype=zero_point.dtype, device=device + ) + + return scale.to(torch.double), zero_point.to(torch.int64) + + +def _get_num_pos_args(f: Callable) -> int: + """Get number of positional args for a function + + Example:: + + >> def f(self, key1=3, key2=3): + pass + >> _get_num_pos_args(f) + 3 + """ + return len(getfullargspec(f).args) + + +def get_fqn_to_example_inputs( + model: torch.nn.Module, example_inputs: tuple[Any, ...] +) -> dict[str, tuple[Any, ...]]: + """Given a model and its example inputs, return a dictionary from + fully qualified name of submodules to example_inputs for that submodule, + e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,), + "sub.linear1": (tensor4,), ...} + + Used to make quantizing submodules easier now that FX Graph Mode Quantization requires + example inputs. + + Also works for keyword arguments with default values, we would flatten keyword + arguments as positional arguments and fill in the missing keyword args with default + values, e.g. if we have a forward function: + def forward(self, x, key1=3, key2=3): + ... + + and we call it with self.submodule(x, key2=6) + we'll get example_inputs: (x, 3, 6) + + user can also override `key1` with positional arguments as well: + for self.submodule(x, 5, key2=6) + we'll get: (x, 5, 6) + + variable positional arguments and variable positional keyword arguments in forward + function are not supported currently, so please make sure no submodules is using + them. + """ + root = model + fqn_to_example_inputs = {} + + def _patched_module_call(self, *args, **kwargs): + submodule_example_inputs = list(args).copy() + normalized_kwargs = _normalize_kwargs(self.forward, kwargs) + # minus 1 to skipping counting `self` + num_args = _get_num_pos_args(self.forward) - 1 + num_to_pop = num_args - len(submodule_example_inputs) + while num_to_pop and normalized_kwargs: + normalized_kwargs.popitem(last=False) + num_to_pop -= 1 + submodule_example_inputs.extend(normalized_kwargs.values()) + submodule_example_inputs_tuple = tuple(submodule_example_inputs) + fqn = _get_path_of_module(root, self) + if fqn is not None: + fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple + return orig_module_call(self, *args, **kwargs) + + orig_module_call = torch.nn.Module.__call__ + torch.nn.Module.__call__ = _patched_module_call # type: ignore[method-assign] + try: + model(*example_inputs) + finally: + # restore the module call even if there is an exception + torch.nn.Module.__call__ = orig_module_call # type: ignore[method-assign] + return fqn_to_example_inputs + + +def _assert_and_get_unique_device(module: torch.nn.Module) -> Any: + """ + Returns the unique device for a module, or None if no device is found. + Throws an error if multiple devices are detected. + """ + devices = {p.device for p in module.parameters()} | { + p.device for p in module.buffers() + } + """ + As a temp workaround for AIMP HHC publish we added CPU check.remove it later. T163614564 + """ + if {torch.device("cpu"), torch.device("meta")} == devices: + warnings.warn( + "Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'." + ) + devices = {torch.device("cpu")} + "" + assert len(devices) <= 1, ( + "prepare only works with cpu or single-device CUDA modules, " + f"but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + return device + + +DEPRECATION_WARNING = ( + "torch.ao.quantization is deprecated and will be removed in 2.10. \n" + "For migrations of users: \n" + "1. Eager mode quantization (torch.ao.quantization.quantize, " + "torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode " + "quantize_ API instead \n" + "2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx," + "torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization " + "API instead (prepare_pt2e, convert_pt2e) \n" + "3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) \n" + "see https://github.com/pytorch/ao/issues/2259 for more details" +) + + +__all__ = [ + "NodePattern", + "Pattern", + "MatchAllNode", + "check_node", + "get_combined_dict", + "is_per_tensor", + "is_per_channel", + "getattr_from_fqn", + "get_qparam_dict", + "get_swapped_custom_module_class", + "activation_dtype", + "weight_dtype", + "activation_is_statically_quantized", + "activation_is_dynamically_quantized", + "activation_is_int8_quantized", + "activation_is_int32_quantized", + "weight_is_quantized", + "weight_is_statically_quantized", + "op_is_int8_dynamically_quantized", + "get_qconfig_dtypes", + "get_quant_type", + "check_min_max_valid", + "calculate_qmin_qmax", + "has_no_children_ignoring_parametrizations", + "get_fqn_to_example_inputs", + "to_underlying_dtype", + "determine_qparams", + "validate_qmin_qmax", + "DEPRECATION_WARNING", +] diff --git a/.venv/lib/python3.12/site-packages/torch/export/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f33e76e08e8e1240f724447ef679d41d324cfb69 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..792ace14c2066fa44f723dbbe6fde8ebb4d5b61c Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_tree_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_tree_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d837c46b48f175414cccf83819ff387817863de5 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/_tree_utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/export/__pycache__/custom_ops.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/custom_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8d93b3f1d8c85b08dd071bfc189495991d9ceac Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/custom_ops.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/export/__pycache__/decomp_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/decomp_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d30ccaa16fa60417687c9d926830e312dccfe5bd Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/decomp_utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18629eb0451ceb51e9f1ba7dc3894d10c3223fbf Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/export/__pycache__/exported_program.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/exported_program.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59fdd4bba5278d09de1691f3fb9a9417ad12173f Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/exported_program.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/export/__pycache__/graph_signature.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/graph_signature.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8358b55c763b201c914d52bedd747685fac082e Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/graph_signature.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/export/__pycache__/unflatten.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/unflatten.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6d0c5831de1f513156988afedcdaa074bfdb2fd Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/export/__pycache__/unflatten.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/export/experimental/__init__.py b/.venv/lib/python3.12/site-packages/torch/export/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a2c46108189a21d944b42a4af4547a4344e23559 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/export/experimental/__init__.py @@ -0,0 +1,326 @@ +import copy +import dataclasses +import functools +import types +import typing +import typing_extensions + +import torch +from torch.export.exported_program import _decompose_exported_program + + +def _copy_graph_module_and_signature( + ep: torch.fx.GraphModule, +) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]: + # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(), + # and this can break placeholder names in some particular cases. + # For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'. + # So we manually overwrite placeholder names by reading the old graph. + gm = copy.deepcopy(ep.graph_module) + new_graph_signature = copy.deepcopy(ep.graph_signature) + + # iterate over old/new graph modules + for old_gm, new_gm in zip(ep.graph_module.modules(), gm.modules()): # type: ignore[union-attr] + old_phs = [node for node in old_gm.graph.nodes if node.op == "placeholder"] + new_phs = [node for node in new_gm.graph.nodes if node.op == "placeholder"] + # iterate over placeholders + assert len(old_phs) == len(new_phs) + for old_node, new_node in zip(old_phs, new_phs): + new_node.name = old_node.name + + return gm, new_graph_signature # type: ignore[return-value] + + +def _remove_detach_pass( + gm: torch.fx.GraphModule, sig: torch.export.graph_signature.ExportGraphSignature +) -> None: + with gm._set_replace_hook(sig.get_replace_hook()): + for node in list(reversed(gm.graph.nodes)): + if node.op != "call_function": + continue + if ( + node.target == torch.ops.aten.detach.default + and len(node.users) == 1 + and next(iter(node.users)).target == torch.ops.aten.detach.default + ): + next(iter(node.users)).replace_all_uses_with(node) + + gm.graph.eliminate_dead_code() + gm.recompile() + + +def _export_forward_backward( + ep: torch.export.ExportedProgram, joint_loss_index: int = 0 +) -> torch.export.ExportedProgram: + """ + WARNING: This API is highly unstable and will be subject to change in the future. + """ + from torch._decomp import core_aten_decompositions + + ep = _decompose_exported_program( + ep, + cia_to_decomp={}, + python_decomp_table=core_aten_decompositions(), + joint_loss_index=joint_loss_index, + # For serialization purpose, we don't want to decompose custom triton ops. + # If users would like to decompose custom triton ops, they could do it + # with run_decompositions() API. + decompose_custom_triton_ops=False, + ) + gm, new_graph_signature = _copy_graph_module_and_signature(ep) + _remove_detach_pass(gm, new_graph_signature) + + return ep._update(gm, new_graph_signature) + + +@typing.no_type_check +def _sticky_export(forward_func, dynamic_shapes_callback=None): + """ + Lazily export the model on first forward call. + Usage: + model.forward = _sticky_export(model.forward, dynamic_shapes_callback=callback) + """ + model = forward_func.__self__ + original_forward = forward_func.__func__ + + @functools.wraps(forward_func) + def wrapper(*args, **kwargs): + # Unpatch forward to avoid recursion during export + model.forward = types.MethodType(original_forward, model) + + dynamic_shapes_spec = None + if dynamic_shapes_callback: + dynamic_shapes_spec = dynamic_shapes_callback(*args, **kwargs) + + try: + exported = torch.export.export( + model, + args, + kwargs, + dynamic_shapes=dynamic_shapes_spec, + ).module() + wrapper._exported_artifact = exported + finally: + # Restore the wrapper after export + model.forward = wrapper + + return exported(*args, **kwargs) + + return wrapper + + +@dataclasses.dataclass +class _ExportMethod: + overloads: dict[str, torch.export.ExportedProgram] + fallbacks: list[torch.export.ExportedProgram] + + +_InputT = typing_extensions.ParamSpec("_InputT") +_RetT = typing.TypeVar("_RetT") + + +class _ExportPackage: + """ + An export package is a collection of torch.export()-ed PyTorch models consisting of + a list of exported methods and their corresponding overloads. ExportPackage is introduced + on top of torch.export() to support the following use cases: + - Exporting a model with multiple methods if a model has multiple independent parts. + - Exporting a function with multiple overloads based on tensor shapes or other metadata. + + ExportPackage is designed to contain multiple methods (associated with method names) and for + each method, it can have multiple overloads (associated with overload names). + + Here is an example of the data structure for an ExportPackage: + ``` + ExportPackage( + methods={ + "decoder": ExportMethod( + overloads={ + "prefill": ExportedProgram(...), + "decode": ExportedProgram(...), + }, + fallbacks=[], + ), + "encoder": ExportMethod(overloads={}, fallbacks=[ExportedProgram(...)]), + }, + ) + ``` + + To export a model into an ExportPackage, users can use the exporter API provided by ExportPackage. + Exporter is a decorator that takes a callable and returns a wrapper. The wrapper will export the + function into an ExportPackage, when it's invoked with some sample inputs (similar to how + torch.compile() works). For more details, please refer to the document on .exporter() method. + + This design allows users to decouple the exported callables from the actual sample inputs which can + be helpful for use cases where the exported callable is hidden behind helper functions or when sample + inpusts are hard to get. + + NOTE: This is an experimental API and anything can be changed in the future. + + Example usage: + ``` + def fn(x): + return x + 1 + + def main(f, x): + x += 1 + ret = f(x) + return ret + 1 + + package = ExportPackage() + main(package.exporter(fn), torch.randn(3, 2)) + ``` + + """ + + def __init__(self) -> None: + self.methods: dict[str, _ExportMethod] = {} + + def _exporter( + self, + method: str, + fn: typing.Callable[_InputT, _RetT], + *, + fallback: str = "once", + ) -> typing.Callable[_InputT, _RetT]: + """ + A function/module decorator that sets up a callable to be exported later invoked. + By default the exporter will only trigger torch.export for once and error on + later invocations. To customize this behavior, users have the following two options: + 1. Call .define_overload() method on the returned wrapper to define an overload. + 2. Adjust the fallback policy using `fallback` argument. + + An "overload" is a named branch for an ExportMethod with a user defined precondition, + typically based on input tensor shapes. It's up to a downstream backend implementation + of ExportMethod to respect the precondition later in inference. + + define_overload() takes arguments like the following: + - A name, for indexing purposes in a backend. + - A callable (spec) that: + - Has the same model input signature as the original model code. + - Returns an optional dynamic shape spec. + + Exporter will only export an overload when the spec callable successfully returns + a result without rasing AssertionError. + + For example: + ``` + package = ExportPackage() + + + def prefill(x, xa, kv_cache): + assert x.shape[1] == 3 + assert kv_cache == {} + + + def decode(x, xa, kv_cache): + assert x.shape[1] > 1 + assert len(kv_cache) > 0 + return {...} # dynamic shape specs here + + + exporter = ( + package.exporter(decoder) + .define_overload("prefill", prefill) + .define_overload("decode", decode) + ) + ``` + + A "fallback" is exported when no overload precondition matches a given set of sample + inputs. Overloads should + Fallbacks don't have names and are ordered in a list. It's up to a backend to decide + which fallback is used amony multiple ones. + + A reference backend implementation of ExportMethod may look like the following: + ``` + def execute(method: ExportMethod, *args, **kwargs): + for overload in method.overloads: + if match_precondition(overload, *args, **kwargs): + return execute_overload(overload, *args, **kwargs) + for fallback in method.fallbacks: + if match_precondition(fallback, *args, **kwargs): + return execute_fallback(fallback, *args, **kwargs) + ``` + + Args: + method(str): The method name for an exported part of PyTorch model. This + will be saved together with the exported/compiled artifacts + in any serialization format and can be used as the key to + index ExportPackage methods later. + fn(callable): A PyTorch function/module to be exported. + fallback(str): The fallback policy to decide when to call torch.export + - "once" is the default policy. Under this policy a PyTorch program is assumed + to be only called once later and an error will be raised for subsequent + runs. + - "error" means the ExportMethod will never have any fallbacks, meaning + users should define all the possible overloads ahead of time. + + """ + + fallbacks: list[torch.export.ExportedProgram] = [] + specs: dict[str, typing.Callable[_InputT, typing.Any]] = {} + overloads: dict[str, torch.export.ExportedProgram] = {} + self.methods[method] = _ExportMethod(fallbacks=fallbacks, overloads=overloads) + + @functools.wraps(fn) + def _exporter_context(*args, **kwargs): # type: ignore[no-untyped-def] + import torch.export._wrapper_utils + + model: torch.nn.Module + if not isinstance(fn, torch.nn.Module): + model = torch.export._wrapper_utils._WrapperModule(fn) + else: + model = fn + + for k, v in specs.items(): + try: + if isinstance(fn, torch.nn.Module): + dynamic_shapes = v(fn, *args, **kwargs) # type: ignore[arg-type] + else: + dynamic_shapes = v(*args, **kwargs) + except AssertionError: + continue + if k not in overloads: + ep = torch.export.export( + model, args, kwargs, dynamic_shapes=dynamic_shapes + ) + overloads[k] = ep + ep = overloads[k] + return ep.module()(*args, **kwargs) + + if fallback == "error": + raise RuntimeError( + f"Exporter: Cannot export fallback {fn} when fallback policy is set to 'error'," + + "please specify an overload or adjust the fallback policy." + ) + elif fallback == "once": + if len(fallbacks) > 0: + raise RuntimeError( + f"Exporter: Cannot export {fn} more than once, " + + "please specify an overload or adjust the fallback policy." + ) + else: + raise RuntimeError(f"Unknown fallback policy: {fallback}") + ep = torch.export.export(model, args, kwargs) + + fallbacks.append(ep) + return ep.module()(*args, **kwargs) + + if isinstance(fn, torch.nn.Module): + _exporter_context = torch._dynamo.eval_frame.OptimizedModule( # type: ignore[assignment] # noqa: F811 + fn, lambda _: _exporter_context + ) + + def _define_overload( + overload: str, spec: typing.Callable[_InputT, typing.Any] + ) -> typing.Any: + assert overload not in specs + assert callable(spec) + assert overload.isidentifier() + specs[overload] = spec + return _exporter_context + + assert not hasattr(fn, "_define_overload") + _exporter_context._define_overload = _define_overload # type: ignore[attr-defined] + + return _exporter_context diff --git a/.venv/lib/python3.12/site-packages/torch/export/passes/__init__.py b/.venv/lib/python3.12/site-packages/torch/export/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e1d21de660dcd3cd8d68c3f53738b1909429fb7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/export/passes/__init__.py @@ -0,0 +1,70 @@ +from typing import Union + +import torch +import torch.utils._pytree as pytree +from torch.export.exported_program import ExportedProgram + + +__all__ = ["move_to_device_pass"] + + +def move_to_device_pass( + ep: ExportedProgram, location: Union[torch.device, str, dict[str, str]] +) -> ExportedProgram: + """ + Move the exported program to the given device. + + Args: + ep (ExportedProgram): The exported program to move. + location (Union[torch.device, str, Dict[str, str]]): The device to move the exported program to. + If a string, it is interpreted as a device name. + If a dict, it is interpreted as a mapping from + the existing device to the intended one + + Returns: + ExportedProgram: The moved exported program. + """ + + def _get_new_device( + curr_device: torch.device, + location: Union[torch.device, str, dict[str, str]], + ) -> str: + if isinstance(location, dict): + if str(curr_device) in location.keys(): + return location[str(curr_device)] + else: + return str(curr_device) + else: + return str(location) + + # move all the state_dict + for k, v in ep.state_dict.items(): + if isinstance(v, torch.nn.Parameter): + ep._state_dict[k] = torch.nn.Parameter( + v.to(_get_new_device(v.device, location)), + v.requires_grad, + ) + else: + ep._state_dict[k] = v.to(_get_new_device(v.device, location)) + + # move all the constants + for k, v in ep.constants.items(): + if isinstance(v, torch.Tensor): + ep._constants[k] = v.to(_get_new_device(v.device, location)) + + for node in ep.graph.nodes: + # move all the nodes kwargs with burnt-in device + if "device" in node.kwargs: + kwargs = node.kwargs.copy() + kwargs["device"] = _get_new_device(kwargs["device"], location) + node.kwargs = kwargs + # move all the tensor metadata + node.meta["val"] = pytree.tree_map( + lambda v: v.to(_get_new_device(v.device, location)) + if isinstance(v, torch.Tensor) + else v, + node.meta.get("val"), + ) + + ep.validate() + return ep diff --git a/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__init__.py b/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2bf26a275d9eef91f4b6807ac472b2cd0c30b0f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/__init__.py @@ -0,0 +1,4 @@ +from ._package import is_pt2_package, PT2ArchiveReader, PT2ArchiveWriter + + +__all__ = ["PT2ArchiveWriter", "PT2ArchiveReader", "is_pt2_package"] diff --git a/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/_package.py b/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/_package.py new file mode 100644 index 0000000000000000000000000000000000000000..7c97e6abe171ccb8a8f8acb5bd10f954d4bd93cf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/_package.py @@ -0,0 +1,685 @@ +import glob +import io +import json +import logging +import os +import tempfile +import zipfile +from dataclasses import dataclass +from typing import Any, IO, Optional, TYPE_CHECKING, Union +from typing_extensions import TypeAlias + +import torch +import torch.utils._pytree as pytree +from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact +from torch.export._tree_utils import reorder_kwargs +from torch.export.exported_program import ExportedProgram +from torch.export.pt2_archive._package_weights import ( + get_complete, + group_weights, + Weights, +) +from torch.export.pt2_archive.constants import ( + AOTINDUCTOR_DIR, + ARCHIVE_FORMAT_PATH, + ARCHIVE_FORMAT_VALUE, + ARCHIVE_VERSION_PATH, + ARCHIVE_VERSION_VALUE, + CONSTANTS_DIR, + CUSTOM_OBJ_FILENAME_PREFIX, + EXTRA_DIR, + MODELS_DIR, + MODELS_FILENAME_FORMAT, + SAMPLE_INPUTS_FILENAME_FORMAT, + WEIGHT_FILENAME_PREFIX, + WEIGHTS_DIR, +) +from torch.types import FileLike + + +if TYPE_CHECKING: + from torch.utils._ordered_set import OrderedSet + + +DEFAULT_PICKLE_PROTOCOL = 2 +AOTI_FILES: TypeAlias = Union[ + list[Union[str, Weights]], dict[str, list[Union[str, Weights]]] +] + + +logger: logging.Logger = logging.getLogger(__name__) + + +def is_pt2_package(serialized_model: Union[bytes, str]) -> bool: + """ + Check if the serialized model is a PT2 Archive package. + """ + try: + zip_reader = zipfile.ZipFile( + io.BytesIO(serialized_model) + if isinstance(serialized_model, bytes) + else serialized_model + ) + root_folder = zip_reader.namelist()[0].split(os.path.sep)[0] + archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}" + if archive_format_path in zip_reader.namelist(): + return zip_reader.read(archive_format_path) == b"pt2" + except Exception as ex: + logger.info("Model is not a PT2 package: %s", str(ex)) + return False + + +class PT2ArchiveWriter: + """ + Context manager for writing a PT2 archive. + """ + + def __init__(self, archive_path_or_buffer: FileLike): + self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer) # type: ignore[arg-type] + # NOTICE: version here is different from the archive_version + # this is the version of zip file format, which is used by PyTorchFileWriter, which write to /.data/version + # archive_version is the version of the PT2 archive spec, which write to /archive_version + self.archive_file.set_min_version(6) + + def __enter__(self) -> "PT2ArchiveWriter": + return self + + def __exit__(self, *args: Any) -> None: + if not self.has_record(ARCHIVE_FORMAT_PATH): + self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE) + + if not self.has_record(ARCHIVE_VERSION_PATH): + self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE) + + self.close() + + def has_record(self, name: str) -> bool: + """ + Check if a record exists in the archive. + """ + return name in self.archive_file.get_all_written_records() + + def count_prefix(self, prefix: str) -> int: + """ + Count the number of records that start with a given prefix. + """ + return sum( + 1 + for record in self.archive_file.get_all_written_records() + if record.startswith(prefix) + ) + + def write_bytes(self, name: str, data: bytes) -> None: + """ + Write a bytes object to the archive. + name: The destination file inside the archive. + data: The bytes object to write. + """ + assert isinstance(data, bytes), f"Expected bytes but got {type(data)}" + self.archive_file.write_record(name, data, len(data)) + + def write_string(self, name: str, data: str) -> None: + """ + Write a string object to the archive. + name: The destination file inside the archive. + data: The string object to write. + """ + assert isinstance(data, str), f"Expected string but got {type(data)}" + data_bytes = data.encode() + self.write_bytes(name, data_bytes) + + def write_file(self, name: str, file_path: str) -> None: + """ + Copy a file into the archive. + name: The destination file inside the archive. + file_path: The source file on disk. + """ + assert os.path.isfile(file_path), f"{file_path} is not a valid file path" + + with open(file_path, "rb") as f: + file_bytes = f.read() + self.write_bytes(name, file_bytes) + + def write_folder(self, archive_dir: str, folder_dir: str) -> None: + """ + Copy a folder into the archive. + archive_dir: The destination folder inside the archive. + folder_dir: The source folder on disk. + """ + assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path" + + file_paths = filter( + os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True) + ) + for file_path in file_paths: + filename = os.path.relpath(file_path, folder_dir) + archive_path = os.path.join(archive_dir, filename) + self.write_file(archive_path, file_path) + + def close(self) -> None: + """ + Close the archive. + """ + self.archive_file.write_end_of_file() + + +class PT2ArchiveReader: + """ + Context manager for reading a PT2 archive. + """ + + def __init__(self, archive_path_or_buffer: FileLike): + self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type] + assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, ( + "Invalid archive format" + ) + + def __enter__(self) -> "PT2ArchiveReader": + return self + + def __exit__(self, *args: Any) -> None: + # torch._C.PyTorchFileReader doesn't have a close method + pass + + def read_bytes(self, name: str) -> bytes: + """ + Read a bytes object from the archive. + name: The source file inside the archive. + """ + return self.archive_file.get_record(name) + + def read_string(self, name: str) -> str: + """ + Read a string object from the archive. + name: The source file inside the archive. + """ + data = self.read_bytes(name) + return data.decode() + + def archive_version(self) -> int: + """ + Get the archive version. + """ + try: + archive_version = self.read_string(ARCHIVE_VERSION_PATH) + except Exception: + # if archive_version is not found, it means the archive is older than version 0. + # In this case, we assume the archive is version 0. + archive_version = "0" + + return int(archive_version) + + def get_file_names(self) -> list[str]: + """ + Get the file names in the archive. + """ + return self.archive_file.get_all_records() + + +def _package_aoti_files( + archive_writer: PT2ArchiveWriter, + aoti_files: Optional[AOTI_FILES], + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> None: + if aoti_files is None: + return + + if isinstance(aoti_files, list): + aoti_files = {"model": aoti_files} + + assert isinstance(aoti_files, dict) + + all_weights: dict[str, Weights] = {} # model_name -> weight + weights_configs: dict[ + str, dict[str, Any] + ] = {} # model_name -> (weight_name -> (filename, shape, stride, offset)) + + for model_name, files in aoti_files.items(): + num_so_files = 0 + weights_configs[model_name] = {} + + for file in files: + if file == "": + continue + + if isinstance(file, Weights): + all_weights[model_name] = file + continue + + if file.endswith(".so"): + num_so_files += 1 + if num_so_files > 1: + raise RuntimeError( + f"Multiple .so files found in {files}. " + "You might need to clear your cache " + "directory before calling aoti_compile again." + ) + + filename = os.path.basename(file) + if filename.startswith(CUSTOM_OBJ_FILENAME_PREFIX): + new_filepath = os.path.join(CONSTANTS_DIR, filename) + else: + new_filepath = os.path.join(AOTINDUCTOR_DIR, model_name, filename) + logger.debug( + "Saving AOTI generated file %s to archive in %s", file, new_filepath + ) + archive_writer.write_file( + str(new_filepath), + file, + ) + + if len(all_weights) > 0: + # Dedup weights + grouped_tensors: list[OrderedSet[tuple[str, str]]] = group_weights(all_weights) + for idx, group in enumerate(grouped_tensors): + filename = f"{WEIGHT_FILENAME_PREFIX}{idx}" + model_name, weight_name = get_complete(group, all_weights) + complete_tensor, _ = all_weights[model_name].get_weight(weight_name) + buffer = io.BytesIO() + torch.save(complete_tensor, buffer, pickle_protocol=pickle_protocol) + archive_writer.write_bytes( + os.path.join(WEIGHTS_DIR, filename), buffer.getvalue() + ) + for model_name, weight_name in group: + _, w_property = all_weights[model_name].get_weight(weight_name) + weights_configs[model_name][weight_name] = ( + filename, + w_property.shape, + w_property.stride, + w_property.offset, + ) + + for model_name, weights_config in weights_configs.items(): + archive_writer.write_string( + os.path.join(AOTINDUCTOR_DIR, model_name, "weights_config.json"), + json.dumps(weights_config), + ) + logger.debug("packaging weights_config for model %s", model_name) + logger.debug(weights_config) + + +def _package_exported_programs( + archive_writer: PT2ArchiveWriter, + exported_programs: Optional[Union[ExportedProgram, dict[str, ExportedProgram]]], + opset_version: Optional[dict[str, int]] = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> None: + if exported_programs is None: + return + + if isinstance(exported_programs, ExportedProgram): + exported_programs = {"model", exported_programs} # type: ignore[assignment] + + assert isinstance(exported_programs, dict) + + for model_name, ep in exported_programs.items(): + artifact: SerializedArtifact = serialize(ep, opset_version, pickle_protocol) + + archive_writer.write_bytes( + MODELS_FILENAME_FORMAT.format(model_name), artifact.exported_program + ) + # TODO:Consider dedup this with the weights saved in package_aoti_files + archive_writer.write_bytes(f"{WEIGHTS_DIR}{model_name}.pt", artifact.state_dict) + archive_writer.write_bytes( + f"{CONSTANTS_DIR}{model_name}.pt", artifact.constants + ) + archive_writer.write_bytes( + SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name), + artifact.example_inputs, + ) + + +def _package_extra_files( + archive_writer: PT2ArchiveWriter, extra_files: Optional[dict[str, Any]] +) -> None: + if extra_files is None: + return + + for extra_file_name, content in extra_files.items(): + archive_writer.write_string(f"{EXTRA_DIR}{extra_file_name}", content) + + +def package_pt2( + f: FileLike, + *, + exported_programs: Optional[ + Union[ExportedProgram, dict[str, ExportedProgram]] + ] = None, + aoti_files: Optional[AOTI_FILES] = None, + extra_files: Optional[dict[str, Any]] = None, + opset_version: Optional[dict[str, int]] = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> FileLike: + """ + Saves the artifacts to a PT2Archive format + (https://docs.google.com/document/d/1RQ4cmywilnFUT1VE-4oTGxwXdc8vowCSZsrRgo3wFA8/edit?tab=t.0#heading=h.v2y2jgnwc56a). + The artifact can then be loaded using ``load_pt2``. + + Args: + f (str | os.PathLike[str] | IO[bytes]) A file-like object (has to + implement write and flush) or a string containing a file name. + + exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]): + The exported program to save, or a dictionary mapping model name to an + exported program to save. The exported program will be saved under + models/*.json. If only one ExportedProgram is specified, this will + automatically be named "model". + + aoti_files (Union[list[str], dict[str, list[str]]): A list of files + generated by AOTInductor via + ``torch._inductor.aot_compile(..., {"aot_inductor.package": True})``, + or a dictionary mapping model name to its AOTInductor generated files. + If only one set of files is specified, this will automatically be named + "model". + + extra_files (Optional[Dict[str, Any]]): Map from filename to contents + which will be stored as part of the pt2. + + opset_version (Optional[Dict[str, int]]): A map of opset names + to the version of this opset + + pickle_protocol: can be specified to override the default protocol + + """ + assert not ( + exported_programs is None and aoti_files is None and extra_files is None + ), ( + "No value passed in for `exported_programs`, `aoti_files`, and " + "`extra_files`, implying that you do not plan on saving anything." + ) + + if not ( + (isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable()) + or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) + ): + # TODO: turn this into an error + logger.warning( + "Expect archive file to be a file ending in .pt2, or is a buffer. " + "Instead got {%s}", + f, + ) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + with PT2ArchiveWriter(f) as archive_writer: + _package_exported_programs( + archive_writer, exported_programs, pickle_protocol=pickle_protocol + ) + _package_aoti_files( + archive_writer, + aoti_files, + pickle_protocol=pickle_protocol, + ) + _package_extra_files(archive_writer, extra_files) + + if isinstance(f, (io.IOBase, IO)): + f.seek(0) + return f + + +class AOTICompiledModel: + """ + Callable AOT Inductor loaded model from a .pt2 + """ + + def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None: + self.loader = loader + + def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] + call_spec = self.loader.get_call_spec() + in_spec = pytree.treespec_loads(call_spec[0]) + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] + flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)] + flat_outputs = self.loader.boxed_run(flat_inputs) + return pytree.tree_unflatten(flat_outputs, out_spec) + + def get_metadata(self) -> dict[str, str]: + return self.loader.get_metadata() + + def load_constants( + self, + constants_map: dict[str, torch.Tensor], + *, + check_full_update: bool, + user_managed: bool = False, + ) -> None: + """ + Given a mapping of constant fqns to tensors, load the constants into the model. + You can use ``get_constant_fqns`` to get the list of constant fqns that + are needed in the compiled model. + + Args: + constants_map: A mapping of constant fqns to tensors. + check_full_update: Whether to add check to see if all the constants + are updated and have values. + """ + self.loader.load_constants( + constants_map, False, check_full_update, user_managed + ) + + def get_constant_fqns(self) -> list[str]: + return self.loader.get_constant_fqns() + + def __deepcopy__(self, memo: Optional[dict[Any, Any]]) -> "AOTICompiledModel": + logger.warning( + "AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied." + ) + return AOTICompiledModel(self.loader) + + +@dataclass +class PT2ArchiveContents: + exported_programs: dict[str, ExportedProgram] + aoti_runners: dict[str, AOTICompiledModel] + extra_files: dict[str, Any] + + +def _load_exported_programs( + archive_reader: PT2ArchiveReader, + file_names: list[str], + expected_opset_version: Optional[dict[str, int]], +) -> dict[str, ExportedProgram]: + exported_program_files = [ + file for file in file_names if file.startswith(MODELS_DIR) + ] + exported_programs = {} + for file in exported_program_files: + prefix, suffix = MODELS_FILENAME_FORMAT.split( + "{}" + ) # split "models/{}.json" into "models/" and "json" + model_name = file[ + len(prefix) : -len(suffix) + ] # given "models/foo.json" we can now get "foo" + + weights_file = f"{WEIGHTS_DIR}{model_name}.pt" + constants_file = f"{CONSTANTS_DIR}{model_name}.pt" + sample_inputs_file = SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name) + + serialized_exported_program = archive_reader.read_bytes(file) + serialized_weights = archive_reader.read_bytes(weights_file) + serialized_constants = archive_reader.read_bytes(constants_file) + serialized_sample_inputs = archive_reader.read_bytes(sample_inputs_file) + + artifact: SerializedArtifact = SerializedArtifact( + serialized_exported_program, + serialized_weights, + serialized_constants, + serialized_sample_inputs, + ) + + # Deserialize ExportedProgram + ep = deserialize(artifact, expected_opset_version) + exported_programs[model_name] = ep + + return exported_programs + + +def _load_extra_files( + archive_reader: PT2ArchiveReader, file_names: list[str] +) -> dict[str, Any]: + extra_files = [file for file in file_names if file.startswith(EXTRA_DIR)] + + extra_file_contents: dict[str, Any] = {} + for file in extra_files: + contents = archive_reader.read_string(file) + extra_file_contents[file[len(EXTRA_DIR) :]] = contents + + return extra_file_contents + + +def load_pt2( + f: FileLike, + *, + expected_opset_version: Optional[dict[str, int]] = None, + run_single_threaded: bool = False, + num_runners: int = 1, + device_index: int = -1, + load_weights_from_disk: bool = False, +) -> PT2ArchiveContents: # type: ignore[type-arg] + """ + Loads all the artifacts previously saved with ``package_pt2``. + + Args: + f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to + implement write and flush) or a string containing a file name. + + expected_opset_version (Optional[Dict[str, int]]): A map of opset names + to expected opset versions + + num_runners (int): Number of runners to load AOTInductor artifacts + + run_single_threaded (bool): Whether the model should be run without + thread synchronization logic. This is useful to avoid conflicts with + CUDAGraphs. + + device_index (int): The index of the device to which the PT2 package is + to be loaded. By default, `device_index=-1` is used, which corresponds + to the device `cuda` when using CUDA. Passing `device_index=1` would + load the package to `cuda:1`, for example. + + Returns: + A ``PT2ArchiveContents`` object which contains all the objects in the PT2. + """ + + if not ( + (isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable()) + or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) + ): + # TODO: turn this into an error in 2.9 + logger.warning( + "Unable to load package. f must be a buffer or a file ending in " + ".pt2. Instead got {%s}", + f, + ) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + weights = {} + weight_maps = {} + with PT2ArchiveReader(f) as archive_reader: + version = archive_reader.read_string(ARCHIVE_VERSION_PATH) + if version != ARCHIVE_VERSION_VALUE: + raise ValueError( + f"Saved archive version {version} does not match our current " + f"archive version {ARCHIVE_VERSION_VALUE}." + ) + + file_names = archive_reader.get_file_names() + + exported_programs = _load_exported_programs( + archive_reader, file_names, expected_opset_version + ) + extra_files = _load_extra_files(archive_reader, file_names) + + # Get a list of AOTI model names + aoti_model_names: set[str] = set() + for file in file_names: + if file.startswith(AOTINDUCTOR_DIR): + file_end = file[ + len(AOTINDUCTOR_DIR) : + ] # remove data/aotinductor/ prefix + model_name = file_end.split("/")[ + 0 + ] # split "model_name/...cpp" into "model_name" + aoti_model_names.add(model_name) + if load_weights_from_disk and file.endswith("weights_config.json"): + weight_map = json.loads(archive_reader.read_string(file)) + weight_maps[model_name] = weight_map + elif load_weights_from_disk and file.startswith(WEIGHTS_DIR): + weight_file_name = file[ + len(WEIGHTS_DIR) : + ] # remove data/weights/ prefix + weight_bytes = archive_reader.read_bytes(file) + loaded_weight = torch.load(io.BytesIO(weight_bytes)) + weights[weight_file_name] = loaded_weight + + if isinstance(f, (io.IOBase, IO)): + if len(aoti_model_names) > 0: + # Workaround for AOTIModelPackageLoader not reading buffers + with tempfile.NamedTemporaryFile(suffix=".pt2") as tf: + f.seek(0) + tf.write(f.read()) + f.seek(0) + logger.debug("Writing buffer to tmp file located at %s.", tf.name) + + aoti_runners = { + model_name: AOTICompiledModel( + torch._C._aoti.AOTIModelPackageLoader( + tf.name, + model_name, + run_single_threaded, + num_runners, + device_index, + ) + ) + for model_name in aoti_model_names + } + else: + aoti_runners = {} + else: + aoti_runners = { + model_name: AOTICompiledModel( + torch._C._aoti.AOTIModelPackageLoader( + f, model_name, run_single_threaded, num_runners, device_index + ) + ) + for model_name in aoti_model_names + } + + if weight_maps: + for model_name in aoti_model_names: + model_weights = {} + for weight_name, (file, shape, stride, storage_offset) in weight_maps[ + model_name + ].items(): + weight = weights[file] + model_weights[weight_name] = weight.as_strided( + shape, stride, storage_offset + ) + + # user_managed=True ensures the weights updates are shared by all runners. + aoti_runners[model_name].load_constants( + model_weights, check_full_update=True, user_managed=True + ) + + return PT2ArchiveContents(exported_programs, aoti_runners, extra_files) + + +def load_weights_to_pt2_contents( + pt2_contents: PT2ArchiveContents, weights_map: dict[str, Any] +) -> None: + """ + Load weights into the models in PT2 archive contents + + Args: + pt2_contents (PT2ArchiveContents): The contents of the PT2 archive. + """ + for model_name, weights in weights_map.items(): + if model_name not in pt2_contents.aoti_runners: + raise RuntimeError(f"Model {model_name} not found in PT2 archive contents.") + pt2_contents.aoti_runners[model_name].load_constants( + weights, check_full_update=True, user_managed=True + ) diff --git a/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/_package_weights.py b/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/_package_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..e6721ea9229a6bcbc3a6a7abf37d2f62f49d8eb8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/_package_weights.py @@ -0,0 +1,101 @@ +import collections + +import torch +from torch.utils._ordered_set import OrderedSet + + +def _end_ptr(tensor: torch.Tensor) -> int: + if tensor.nelement(): + stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size() + else: + stop = tensor.data_ptr() + return stop + + +class TensorProperties: + def __init__(self, tensor: torch.Tensor): + # info about underlying storage + self.storage_ptr = tensor.untyped_storage().data_ptr() + self.storage_size = tensor.untyped_storage().nbytes() + + # info to recover tensor + self.shape = tensor.shape + self.stride = tensor.stride() + self.offset = tensor.storage_offset() + + self.start = tensor.data_ptr() + self.end = _end_ptr(tensor) + + def is_complete(self) -> bool: + """ + Whehter the tensor completely overlaps with its underlying storage + """ + return ( + self.start == self.storage_ptr + and self.end == self.storage_ptr + self.storage_size + ) + + +class Weights(dict): + """ + A dictionary mapping from weight name to a tuple of (tensor, TensorProperties). + tensor represents the actual intial value of the weight. + TensorProperties represents the properties of the weight that are needed to recover the weight. + + We use two separate entries because `tensor` could be a clone of the original weight tensor, + so it doesn't have the same property as the original weight (such as underlying storage pointer). + """ + + def __init__(self, weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]]): + super().__init__(weight_dict) + + def get_weight(self, name: str) -> tuple[torch.Tensor, TensorProperties]: + return self[name] + + def get_weight_properties(self, name: str) -> TensorProperties: + return self[name][1] + + +def get_complete( + group: OrderedSet[tuple[str, str]], models_weights: dict[str, Weights] +) -> tuple[str, str]: + """ + `group` is a (model_name, weight_name) tuple. + `model_weights` is a dictionary mapping from model name to its Weights. + + One of the tensor in `group` must be complete and they must share the + same underlying storage. + + Returns the name of the complete tensor in the `group`. If multiple + tensors are complete, returns an arbitrary one. + """ + + def get_tensor_properties(name_tuple: tuple[str, str]) -> TensorProperties: + # returns the tensor properties + (model_name, weight_name) = name_tuple + return models_weights[model_name].get_weight_properties(weight_name) + + for name_tuple in group: + tensor_property = get_tensor_properties(name_tuple) + if tensor_property.is_complete(): + return name_tuple + + raise RuntimeError("No complete tensor found in the group!") + + +def group_weights(all_weights: dict[str, Weights]) -> list[OrderedSet[tuple[str, str]]]: + """ + Group weights that share the same underlying storage. + + Returns a list of sets, each set contains a tuple of (model_name, weight_name). + """ + + weights_dict: dict[int, OrderedSet[tuple[str, str]]] = collections.defaultdict( + OrderedSet + ) # storage_key -> set(weight) + + for model_name, weights in all_weights.items(): + for weight_name, (_, properties) in weights.items(): + weights_dict[properties.storage_ptr].add((model_name, weight_name)) + + return list(weights_dict.values()) diff --git a/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/constants.py b/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..3fbf9c69fc1bae05e3b653f9ceaa410126666a2a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/export/pt2_archive/constants.py @@ -0,0 +1,28 @@ +# Defined in torch/csrc/export/pt2_archive_constants.h +from torch._C._export import pt2_archive_constants + + +AOTINDUCTOR_DIR: str = pt2_archive_constants.AOTINDUCTOR_DIR +ARCHIVE_FORMAT_PATH: str = pt2_archive_constants.ARCHIVE_FORMAT_PATH +ARCHIVE_FORMAT_VALUE: str = pt2_archive_constants.ARCHIVE_FORMAT_VALUE +ARCHIVE_ROOT_NAME: str = pt2_archive_constants.ARCHIVE_ROOT_NAME +ARCHIVE_VERSION_PATH: str = pt2_archive_constants.ARCHIVE_VERSION_PATH +ARCHIVE_VERSION_VALUE: str = pt2_archive_constants.ARCHIVE_VERSION_VALUE +CONSTANTS_DIR: str = pt2_archive_constants.CONSTANTS_DIR +CUSTOM_OBJ_FILENAME_PREFIX: str = pt2_archive_constants.CUSTOM_OBJ_FILENAME_PREFIX +EXTRA_DIR: str = pt2_archive_constants.EXTRA_DIR +MODELS_DIR: str = pt2_archive_constants.MODELS_DIR +MODELS_FILENAME_FORMAT: str = pt2_archive_constants.MODELS_FILENAME_FORMAT +MODULE_INFO_PATH: str = pt2_archive_constants.MODULE_INFO_PATH +MTIA_DIR: str = pt2_archive_constants.MTIA_DIR +SAMPLE_INPUTS_DIR: str = pt2_archive_constants.SAMPLE_INPUTS_DIR +SAMPLE_INPUTS_FILENAME_FORMAT: str = pt2_archive_constants.SAMPLE_INPUTS_FILENAME_FORMAT +TENSOR_CONSTANT_FILENAME_PREFIX: str = ( + pt2_archive_constants.TENSOR_CONSTANT_FILENAME_PREFIX +) +WEIGHT_FILENAME_PREFIX: str = pt2_archive_constants.WEIGHT_FILENAME_PREFIX +WEIGHTS_DIR: str = pt2_archive_constants.WEIGHTS_DIR +XL_MODEL_WEIGHTS_DIR: str = pt2_archive_constants.XL_MODEL_WEIGHTS_DIR +XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH: str = ( + pt2_archive_constants.XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH +) diff --git a/.venv/lib/python3.12/site-packages/torch/fft/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fft/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc0a7ed327de4464955fac6a308691c5e3418336 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fft/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/futures/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/futures/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f19b783e7cdf271c4f071ec9cf69f579e93b008f Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/futures/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6afbaf306ec142e1f32e62ff0fdbbf352257737 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_compatibility.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_compatibility.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c05d3116c2bb5b888527dd191661a91c82d6e0cf Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_compatibility.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21f0de99497dfdd1f5d8f0d8f130f64a0e417017 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_pytree.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_pytree.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..555a9b02b6114819a38b07450539db2bfe6110dc Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_pytree.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cb0b36892e09f702983ee77d676109fb339989b Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc096b44c6ec4a1b76e94ea427d86f9fc799e2cf Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/_utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/config.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42f0f256691fd13db6114ffbf39536d68f2355d9 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/config.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/graph.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/graph.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ac1e23ff9bdf35ceb60945a291fd892fd6ce344 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/graph.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/graph_module.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/graph_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3311d9c3f77d4f9c3ac7da6ff8d700df8faee08d Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/graph_module.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/immutable_collections.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/immutable_collections.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5cbd4601da8a6d3fd989534bdcb9511bc9bdbd1 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/immutable_collections.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/interpreter.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/interpreter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d50b3c87ce4e2f2eb99604e9c4f0cd7927cac947 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/interpreter.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/node.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/node.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6f2960414e6d518fd5396c55021ed0aa5f7cd10 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/node.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/operator_schemas.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/operator_schemas.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..415c8f5334ba80ac724821398a7273fade0b3f25 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/operator_schemas.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/proxy.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/proxy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c006a5c07be1800019695a86b92cf400b74af2a0 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/proxy.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..697af088afcada2f739e982ab26abb124ef43ce0 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/traceback.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/traceback.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2234a9f2c8e8c4b297a42f7c03fd848226633fdb Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/__pycache__/traceback.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/__init__.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fd66bb0b73bd21bdffd9bf35b1b0689c36bcb19 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38a1fb566e17f482cbd421e392813f2d45ea40b8 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_config.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6b92f33d4ac4b9e24dc7584fb63f36813d1635a Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_config.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_constant_symnode.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_constant_symnode.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8b8eaa9a6b3410fa8c6c3c7799d8287546cbdf4 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_constant_symnode.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_dynamism.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_dynamism.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91ae9c122919a37b5c02d9b17e2818ddf70a0fc8 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/_dynamism.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fb6886285be9a4e1e1dc324ab902d77bae5f4d0 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/recording.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/recording.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d57e482bbbd354d9c0d28eeb47ae1f04fa4f071 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/recording.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b85dc1e0889051f4b123e36a502fabaff85cff82 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/_backward_state.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/_backward_state.py new file mode 100644 index 0000000000000000000000000000000000000000..9c742431857c33af22dbc1ad73b5bdfcf6124b9c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/_backward_state.py @@ -0,0 +1,27 @@ +import torch.fx + + +class BackwardState: + """ + BackwardState is used to pass Python hooks from the forwards pass + into the backwards pass in Dynamo+Compiled Autograd. + + It is created by TorchDynamo and has special handling there. + Dynamo will pass an empty BackwardState to the forwards, then populate + members on it (via setattr) only after the forwards graph is finished. + Later on, in CompileAutograd we will inline and add the needed guards + on the BackwardState. + + BackwardState is identified and has special handling in AOTAutograd. + During AOTAutograd: + 1) BackwardState is an input to the forwards graph + 2) It must only be used in the backwards + 3) It will be empty in the forwards + 4) In the forwards we add a wrapper to save it + 5) In the backwards it becomes an input + 6) There can only be one per graph + + BackwardState requires CompiledAutograd. + """ + + proxy: torch.fx.Proxy diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/_config.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4296b6410c9bd4d4198267ff64d66125c494e8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/_config.py @@ -0,0 +1,106 @@ +import os +import sys +from typing import Optional + + +# [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors. +no_data_dependent_graph_break = ( + os.environ.get("TORCHDYNAMO_NO_DATA_DEPENDENT_GRAPH_BREAK", "0") == "1" +) +# [@compile_ignored: debug] Uses z3 for validating the guard optimizations transformations. +translation_validation = ( + os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1" +) +# Timeout (in milliseconds) for z3 finding a solution. +# [@compile_ignored: debug] +translation_validation_timeout = int( + os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000") +) +# Disables bisection for translation validation. +# +# Translation validation bisection is enabled by default, if translation validation +# is also enabled. This should help finding guard simplification issues. However, +# since validation uses Z3 for bisecting, it might take a lot of time. +# +# Set this configuration option so as to avoid bisecting. +# [@compile_ignored: debug] +translation_validation_no_bisect = ( + os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1" +) +# Checks whether replaying ShapeEnv events on a freshly constructed one yields +# the a ShapeEnv with the same state. This should be used only in testing. +check_shape_env_recorded_events = False + +# TODO: Perhaps consider allowing unions for the configs below (so you can hit +# multiple reps at the same time) + +# Give extended debug information if the string representation of a guard +# matches this. For example, set this to "Ne(s0, 10)" and whenever we issue +# this guard, we will generate full Python and C++ backtrace +# [@compile_ignored: debug] +extended_debug_guard_added = os.environ.get( + "TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED", None +) + +# Give extended debug information when a particular symbol is allocated. For +# example, set this to "u2" and whenever we create this symbol, we will +# generate full Python and C++ backtrace +# [@compile_ignored: debug] +extended_debug_create_symbol = os.environ.get( + "TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL", None +) + +# Give extended debug information (C++ backtrace) for all extended debug +# settings as well as errors. The C++ backtrace is slow and very spammy so we +# don't include it by default even when you're requesting extended debug. +# [@compile_ignored: debug] +extended_debug_cpp = os.environ.get("TORCHDYNAMO_EXTENDED_DEBUG_CPP", "") != "" + +# Give extended debug information (line of code) when a torch function +# is called during export. This is useful for showing progress and detecting +# where export might be stuck. Currently only works for strict=False. +# [@compile_ignored: debug] +extended_debug_current_loc = ( + os.environ.get("TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC", "0") == "1" +) + +# [@compile_ignored: debug] Show a warning for every specialization +print_specializations = False + +# wraps (un)equalities with 'Not' class after recording the correct expression +# in the FX graph. This should incorrectly construct the divisible and replacement +# lists, and incorrectly issue guards. +inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False + +# [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly +validate_shape_env_version_key = False + +# If we produce more than this many guards on a symbol, force the symbol to +# get specialized and bail out if this many guards mention this particular +# symbol. This may be slightly more aggressive than the true number of guards +# issued (as we test if we've hit the limit on-the-fly, whereas we may +# do further simplifications at final guard issuance time that make guards +# irrelevant.) +symbol_guard_limit_before_specialize: Optional[int] = None + +# This flag changes whether we should use the same symbolic variable to represent input sizes that are the same. +use_duck_shape = True + +# Controls the registration of torch.nonzero() on the meta device. +# When True, nonzero returns a tensor with shape (self.numel(), self.dim()) +# assuming all elements are none-zero. +# Default is False to prevent unintended registration. Set to True to enable. +meta_nonzero_assume_all_nonzero = False + +# Applies size-oblivious reasoning to backed symbols. This allocates a [0, inf] range for backed size symbols, +# and relies on size-oblivious semantics to avoid 0/1 specialization guards by marking them size-like. +# Currently an experimental option for export. +backed_size_oblivious = False + +# Skip dtype check in meta registrations. Only used for systems that does its own dtype checking. +skip_dtype_check_in_meta_registrations = False + +from torch.utils._config_module import install_config_module + + +install_config_module(sys.modules[__name__]) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/_constant_symnode.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/_constant_symnode.py new file mode 100644 index 0000000000000000000000000000000000000000..c45728d24d1ddba4f315f8cfd13b7827b4fbbb16 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/_constant_symnode.py @@ -0,0 +1,69 @@ +from typing import * # noqa: F403 + + +# Python version of c10/core/ConstantSymNodeImpl.cpp +# This needs to exist because the Python version of nested int is not compatible +# with the C++ version of constant symnode. +class ConstantIntNode: + def __init__(self, val: int): + self.val = val + + def is_constant(self) -> bool: + return True + + def maybe_as_int(self) -> int: + return self.val + + def is_int(self) -> bool: + return True + + def is_float(self) -> bool: + return False + + def is_bool(self) -> bool: + return False + + def is_nested_int(self) -> bool: + return False + + def clone(self) -> "ConstantIntNode": + return self + + def _str(self) -> str: + return str(self.val) + + def __str__(self) -> str: + return self._str() + + def __repr__(self) -> str: + return self._str() + + def _graph_repr(self) -> str: + return self._str() + + def mul(self, other: Any) -> Any: + return other.mul(self) + + def eq(self, other: Any) -> Any: + return other.eq(self) + + def ne(self, other: Any) -> Any: + return other.ne(self) + + def gt(self, other: Any) -> Any: + return other.lt(self) + + def lt(self, other: Any) -> Any: + return other.gt(self) + + def le(self, other: Any) -> Any: + return other.ge(self) + + def ge(self, other: Any) -> Any: + return other.le(self) + + def is_symbolic(self) -> bool: + return False + + def constant_int(self) -> int: + return self.val diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/_dynamism.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/_dynamism.py new file mode 100644 index 0000000000000000000000000000000000000000..4828b6f458eb40dfcd551fcc3d3cf0e25e8694a9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/_dynamism.py @@ -0,0 +1,118 @@ +import re +from typing import Any, Callable, Union + +import torch +from torch.utils._pytree import tree_flatten_with_path, tree_map + + +KeyPath = tuple[Any, ...] +NonTensorShapeFn = Callable[[Union[int, float]], tuple[Any, ...]] + +__all__ = [ + "normalize_source_name", + "module_to_nested_dict", + "track_dynamism_across_examples", + "clone_and_convert_to_meta", +] + + +def normalize_source_name(name: str) -> str: + # Match attribute access like .x and replace with ['x'] + return re.sub(r"\.([a-zA-Z_][a-zA-Z0-9_]*)", r"['\1']", name) + + +def module_to_nested_dict(module: torch.nn.Module) -> dict[str, Any]: + """Recursively converts an nn.Module into a nested dictionary with explicit 'parameters' and 'modules' keys.""" + self_dict: dict[str, Any] = {} + + self_dict["_parameters"] = {} + self_dict["_modules"] = {} + + for attr_name in dir(module): + try: + if not attr_name.startswith("_") and not callable( + getattr(module, attr_name) + ): + attr_value = getattr(module, attr_name) + if ( + not isinstance(attr_value, torch.nn.Module) + and isinstance(attr_value, (int, float, torch.Tensor)) + and type(attr_value) is not bool + ): + self_dict[attr_name] = attr_value + except NotImplementedError: + # Skip attributes that raise NotImplementedError since they won't + # contain any dynamism anyways. + continue + + for name, param in module.named_parameters(recurse=False): + self_dict["_parameters"][name] = param + for name, buffer in module.named_buffers(recurse=False): + self_dict["_parameters"][name] = buffer + + for name, submodule in module.named_children(): + self_dict["_modules"][name] = module_to_nested_dict(submodule) + + return self_dict + + +def track_dynamism_across_examples( + example_inputs: list[Any], +) -> dict[Any, Any]: + """ + This function analyzes a list of example inputs to determine the dynamism of their shapes. + It tracks whether the dimensions of tensors or non-tensor values change across + different examples. The function returns a dictionary where each key represents + a path to a value in the input examples, and the corresponding value is a tuple + indicating which dimensions are dynamic (i.e., change across examples). This + helps in understanding how the structure of data varies across different instances. + """ + tracking: dict[KeyPath, tuple[list[set[Any]], bool]] = {} + + for ex in example_inputs: + if "self" in ex and isinstance(ex["self"], torch.nn.Module): + ex["self"] = module_to_nested_dict(ex["self"]) + leaves_with_paths, _ = tree_flatten_with_path(ex) + for key_path, value in leaves_with_paths: + if not isinstance(value, (int, float, torch.Tensor)): + continue + if isinstance(value, torch.Tensor): + shape: tuple[int | float, ...] = tuple(value.shape) + is_tensor = True + else: + shape = (value,) + is_tensor = False + if key_path not in tracking: + tracking[key_path] = ([set() for _ in range(len(shape))], is_tensor) + else: + dim_sets, flag = tracking[key_path] + if flag != is_tensor: + pass + while len(dim_sets) < len(shape): + dim_sets.append(set()) + for i, dim in enumerate(shape): + tracking[key_path][0][i].add(dim) + + output: dict[Any, Any] = {} + for key_path, (dim_sets, _is_tensor) in tracking.items(): + final_dyn = tuple(len(s) > 1 for s in dim_sets) + key_str = "L" + "".join(f"{str(k)}" for k in key_path) + key = key_path[0].key # type: ignore[attr-defined] + if key not in output: + output[key] = {} + output[key][key_str] = final_dyn + return output + + +def clone_and_convert_to_meta(example_input: Any) -> Any: + """ + This function takes a list of example inputs and for each tensor, clones it and converts it to device=meta. + For non-tensor values, it keeps the reference. It uses pytree to handle nested structures recursively. + """ + + def transform_fn(value: Any) -> Any: + if isinstance(value, torch.Tensor): + return value.clone().to(device="meta") + return value + + return tree_map(transform_fn, example_input) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/accelerator_partitioner.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/accelerator_partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..c29d05f511a79521fa0f7db8478ac21100b6f6fb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/accelerator_partitioner.py @@ -0,0 +1,1080 @@ +# mypy: allow-untyped-defs +import operator +from collections import deque +from typing import NamedTuple + +import torch +from torch.fx.experimental.partitioner_utils import ( + Device, + get_extra_size_of, + get_latency_of_partitioned_graph, + get_partition_to_latency_mapping, + NodeLatency, + Partition, + PartitionerConfig, + PartitionMode, +) +from torch.fx.graph_module import GraphModule +from torch.fx.node import map_arg, Node +from torch.fx.passes.graph_manipulation import get_size_of_all_nodes +from torch.fx.passes.split_module import split_module + + +class DAGNode: + """DAGNode class maintains useful information for a partition (submodule), + and its input submodules and output submodules. + """ + + def __init__( + self, + submodule_node: Node, + input_nodes: list[Node], + output_nodes: list[Node], + logical_device_ids: list[int], + size_bytes: int, + ) -> None: + self.submodule_node: Node = submodule_node + self.input_nodes: list[Node] = input_nodes + self.output_nodes: list[Node] = output_nodes + self.logical_device_ids: list[int] = logical_device_ids + self.size_bytes = size_bytes + + def __str__(self) -> str: + return str(self.submodule_node) + + +class DAG: + """DAG class contains all the DAG nodes""" + + def __init__(self) -> None: + self.nodes: list[DAGNode] = [] + + def create_node( + self, + submodule_node: Node, + input_nodes: list[Node], + output_nodes: list[Node], + logical_devices: list[int], + size_bytes: int, + ) -> None: + node = DAGNode( + submodule_node, input_nodes, output_nodes, logical_devices, size_bytes + ) + self.nodes.append(node) + + +class PartitionResult(NamedTuple): + """NameTuple used for returning DAG and a new fx module""" + + dag: DAG + module_with_submodules: GraphModule + + +"""Followings are some helper functions for partition manipulation""" + + +def reset_partition_device(partitions): + for partition in partitions: + partition.logical_device_ids = [] + + +def combine_two_partitions( + partition_0: Partition, partition_1: Partition, partitions: list[Partition] +) -> None: + """Given a list of partitions and its two partitions, + combine these two partitions into a new one appending to the partitions + and remove the previous two partitions from the list of partitions + """ + partition = Partition(len(partitions)) + partition.nodes = partition_0.nodes.union(partition_1.nodes) + partition.recalculate_mem_size() + partitions.append(partition) + partitions.remove(partition_0) + partitions.remove(partition_1) + reorganize_partitions(partitions) + return + + +def set_parents_and_children(partitions: list[Partition]) -> None: + """Given a list of partitions, mark parents and children for each partition""" + # Go through all nodes in a partition. + # If a node's user is in other partition, + # then the other partition is this partition's children. + # This partition is the other partition's parent + for partition in partitions: + partition.children = set() + partition.parents = set() + for partition in partitions: + for node in partition.nodes: + # For each node in the current partition, find its users + users = node.users + for n in users: + # Find which the partition the user node belongs to. + # Note that if the node itself is also belongs to that partition, + # that partition is not the child of the current partition + for p in partitions: + if p != partition and n in p.nodes and node not in p.nodes: + partition.children.add(p) + p.parents.add(partition) + return + + +def reorganize_partitions(partitions: list[Partition]) -> None: + """Given a list of partitions, reorganize partition id, + its parents and its children for each partition + """ + # Rearrange partition ids + for i, partition in enumerate(partitions): + partition.partition_id = i + set_parents_and_children(partitions) + return + + +def get_bfs_level_partition(partitions: list[Partition]) -> None: + """Given a list of partitions, + mark the bfs level for each partition + """ + current_level: set[Partition] = set() + visited: set[Partition] = set() + for partition in partitions: + # If a partition has no parent, it should be in root level + if len(partition.parents) == 0: + current_level.add(partition) + next_level: set[Partition] = set() + level = 0 + # bfs + while current_level: + partition = current_level.pop() + partition.bfs_level = level + visited.add(partition) + children = partition.children + for child in children: + if child not in next_level: + next_level.add(child) + if not current_level: + current_level = next_level.copy() + next_level = set() + level += 1 + return + + +def get_node_to_partition_mapping(partitions: list[Partition]) -> dict[Node, int]: + """Given a list of partitions,return node to partition mapping""" + node_to_partition: dict[Node, int] = {} + for partition in partitions: + for node in partition.nodes: + node_to_partition[node] = partition.partition_id + return node_to_partition + + +def get_logical_id_to_device(devices: list[Device]) -> dict[int, Device]: + """Get a mapping from device logical ID to Device object.""" + logical_id_to_device: dict[int, Device] = {} + for d in devices: + logical_id_to_device[d.logical_id] = d + return logical_id_to_device + + +def get_device_partition_stats( + partitions: list[Partition], devices: list[Device] +) -> tuple[dict[Device, list[Partition]], dict[Device, int], list[Partition]]: + """Given a list of partitions and a list of devices, returns: + 1. A mapping from device to partitions on it; + 2. A mapping from device to its remaining memory size; + 3. A list of partitions that do not have a device. + """ + # logical id to device + logical_id_to_device = get_logical_id_to_device(devices) + # Track partitions on device + device_to_partitions: dict[Device, list[Partition]] = {} + # Track device's left mem size + device_to_left_mem_bytes: dict[Device, int] = {} + for d in devices: + device_to_partitions[d] = [] + device_to_left_mem_bytes[d] = d.available_mem_bytes + + # Deal with the partitions that already have a device + # and also collect all partitions without a device (no_device_partitions) + no_device_partitions = [] + for partition in partitions: + if partition.logical_device_ids != []: + for logical_id in partition.logical_device_ids: + device = logical_id_to_device[logical_id] + device_to_partitions[device].append(partition) + device_to_left_mem_bytes[device] -= partition.used_mem_bytes + else: + no_device_partitions.append(partition) + + return ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) + + +def get_device_to_partitions_mapping( + partitions: list[Partition], devices: list[Device] +): + """Given a list of partitions and a list of devices, + map each partition into a device. + """ + + def calculate_extra_mem_bytes_needed_for( + partition: Partition, partitions: list[Partition] + ): + all_nodes: set[Node] = set() + for p in partitions: + all_nodes = all_nodes.union(p.nodes) + if len(all_nodes) == 0: + return partition.used_mem_bytes + all_nodes = all_nodes.union(partition.nodes) + extra_size_needed = 0 + for node in partition.nodes: + extra_size_needed += get_extra_size_of(node, all_nodes) + return extra_size_needed + + def find_device_for(partition: Partition): + """Given a partition, find a logical device for the partition + The algorithm is to put the partition on the device + that has just enough mem left for that partition. + device_to_left_mem_bytes is a dictionary between device and its left mem size + sorted by its left mem size + """ + for d in device_to_left_mem_bytes: + extra_size_needed = calculate_extra_mem_bytes_needed_for( + partition, device_to_partitions[d] + ) + if extra_size_needed < device_to_left_mem_bytes[d]: + device_to_partitions[d].append(partition) + partition.logical_device_ids.append(d.logical_id) + device_to_left_mem_bytes[d] -= extra_size_needed + return True + return False + + ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) = get_device_partition_stats(partitions, devices) + + # Find devices for all the partitions without a device + found_device = True + for partition in no_device_partitions: + device_to_left_mem_bytes = dict( + sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1)) + ) + found_device = find_device_for(partition) + if not found_device: + break + return found_device + + +def check_dependency(partition): + """Given a partition,check if there is a circular dependency on + this partition using bfs + """ + visited: set[Partition] = {partition} + queue: deque[Partition] = deque([partition]) + while queue: + p = queue.popleft() + for child in p.children: + if child == partition: + return True + else: + if child not in visited: + visited.add(child) + queue.append(child) + return False + + +class Partitioner: + """A fx module may not fit into one device. + Partitioner class helps partition one fx module into submodules (partitions), + so that the submodules can be executed crossing different accelerators. + The main function of this class is self.partition_graph. + It partitions the fx module based on the scheme specified in partition_config + A DAG structure is returned + along with a new fx module with submodule nodes. + """ + + def __init__(self) -> None: + self.partitions: list[Partition] = [] + self.node_to_partition: dict[Node, int] = {} + self.devices: list[Device] = [] + + def partition_graph( + self, + fx_module: GraphModule, + torch_module: torch.nn.Module, + partitioner_config: PartitionerConfig, + ) -> PartitionResult: + """Given the fx module, torch module and partitioner_config, + find the partitions, do the partitions, + and then return a DAG and a new fx module with submodule nodes (partitions) + """ + self.graph_module = fx_module + self.torch_module = torch_module + self.devices = partitioner_config.devices + if len(self.devices) == 0: + raise RuntimeError("No devices") + # Tag the size in bytes to all nodes in the graph_module. + get_size_of_all_nodes(self.graph_module) + # Check if there are op nodes in the fx module + nodes = self.graph_module.graph.nodes + if all(node.op in {"placeholder", "get_attr", "output"} for node in nodes): + raise RuntimeError("No Partition since no operations in the module") + # Calculate total size of the fx module + total_size_of_graph = 0 + for node in nodes: + if node.op == "output": + break + total_size_of_graph += node.size_bytes.total_size + # Find the device with the max mem size + device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes) + # AOT based partition + if partitioner_config.mode == PartitionMode.aot_based: + self.aot_based_partition( + partitioner_config.node_to_partition_mapping, + partitioner_config.partition_to_logical_device_mapping, + ) + # Single partition if the whole module can be fit into one device + elif total_size_of_graph <= device_with_max_mem.available_mem_bytes: + self.find_single_partition( + total_size_of_graph, logical_device_id=device_with_max_mem.logical_id + ) + elif total_size_of_graph > sum(d.available_mem_bytes for d in self.devices): + raise RuntimeError("Devices have no enough memory for the module") + else: + # Sparse nn based partition + if partitioner_config.mode == PartitionMode.sparse_nn: + available_mem_bytes = self.devices[0].available_mem_bytes + if not all( + device.available_mem_bytes == available_mem_bytes + for device in self.devices + ): + raise RuntimeError("All devices must have same memory size!") + # sparse_nn_partition only support same memory size + # TODO: add different size support for sparse_nn_partition + self.sparse_nn_partition(available_mem_bytes) + # Cost aware partition + elif partitioner_config.mode == PartitionMode.cost_aware: + self.cost_aware_partition( + partitioner_config.transfer_rate_bytes_per_sec, + partitioner_config.node_to_latency_mapping, + ) + # KL based partition + elif partitioner_config.mode == PartitionMode.kl_based: + self.kl_based_partition( + partitioner_config.transfer_rate_bytes_per_sec, + partitioner_config.node_to_latency_mapping, + ) + else: + self.size_based_partition() + + # Saturate host if possible. + if partitioner_config.saturate_host: + self.saturate_host() + + # Partition the graph module based on the partition assignment. + module_with_submodules = self.do_partition() + + # The DAG contains DAGNodes with info of each partition's input nodes, output nodes + # and how partitions are connected. + dag = self.dump_dag(module_with_submodules) + ret = PartitionResult(dag, module_with_submodules) + return ret + + def find_single_partition( + self, total_size_of_graph, logical_device_id: int = 0 + ) -> None: + """Fit the whole fx module into one device""" + partition_0 = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op == "output": + # Skip the output node, but there can + # be nodes after the output in certain cases. + continue + partition_0.nodes.add(node) + partition_0.used_mem_bytes = total_size_of_graph + partition_0.logical_device_ids = [logical_device_id] + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def size_based_partition(self) -> None: + """This method is to partition the fx module based on memory size. + It uses greedy approach. The result may not be the best. + The basic idea is: + Step 1: + Find a device which has enough memory to fit the current node, create a empty partition + with the size of that device. + Then keep adding the following nodes into the partition until the partition is full. + Step 2: + Repeat Step 1 until no device left + Step 3: + If some nodes are left, create a partition for each left node (single node partition). + and then try to map those partitions into logical devices with enough mem left. + """ + + def find_device_based_on_size(node) -> Device: + """Given a node, this function is to find a logical device + that could fit the node. + """ + mem_size_needed = get_extra_size_of(node, set()) + device = Device("", -1, -1) + for d in self.devices: + if ( + d not in occupied_devices + and d.available_mem_bytes >= mem_size_needed + ): + device = d + break + if device.available_mem_bytes < 0: + raise RuntimeError(str(node) + "is too large to fit any device") + occupied_devices.append(device) + return device + + # Track partition and its left mem size + partition_to_left_mem_bytes: dict[Partition, int] = {} + # Track all the devices that have been used + occupied_devices: list[Device] = [] + partition = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op in {"call_module", "call_method", "call_function"}: + # Check if there are devices left + if len(self.partitions) <= len(self.devices): + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + # Check if the current partition is the very first partition + if partition.used_mem_bytes == 0: + # Find a device to fit the first node, return available mem size + device = find_device_based_on_size(node) + occupied_devices.append(device) + # Update partition and its left mem size + partition_to_left_mem_bytes[partition] = ( + device.available_mem_bytes + ) + # Update available mem for the current partition + partition.logical_device_ids.append(device.logical_id) + else: + # The current partition is not the first partition + # Check if the current node can fit into current partition + if ( + partition_to_left_mem_bytes[partition] + < total_size_of_input_nodes + ): + # Check if no device is left + if len(self.partitions) == len(self.devices): + # No device is left + # Create the first single node partition for the current node + self.create_single_node_partition(node) + continue + # Some devices are still left + # Create a new partition with a mem size that is enough for the current node + device = find_device_based_on_size(node) + partition = self.create_partition() + total_size_of_input_nodes = get_extra_size_of( + node, partition.nodes + ) + partition_to_left_mem_bytes[partition] = ( + device.available_mem_bytes + ) + partition.logical_device_ids.append(device.logical_id) + partition.add_node(node) + partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes + # Create single node partitions if no device is left + else: + self.create_single_node_partition(node) + reorganize_partitions(self.partitions) + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + # Mapping all partitions into device + found_partition_to_device_mapping = get_device_to_partitions_mapping( + self.partitions, self.devices + ) + if not found_partition_to_device_mapping: + raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping") + return + + def saturate_host(self) -> None: + """Saturate host by assigning replicates to unused devices with enough memory. + It uses a greedy approach to find a next available set of devices to place all split + partitions: For each used device, it searches for an idle device with minimal memory + size that can hold all the partition located on that device; If the search is successful + for all used devices, it then assigns the new devices' logical ID to the corresponding + partition. + """ + ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) = get_device_partition_stats(self.partitions, self.devices) + + assert len(no_device_partitions) == 0, ( + f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}" + ) + + # Devices that hold partitions + used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0] + # Track replicates of the assigned devices + replicated_device_to_used_device: dict[Device, Device] = {} + + while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len( + self.devices + ): + # Success flag for this round + success = True + # Devices that have not been assigned + idle_devices = [ + d + for d in self.devices + if d not in used_devices and d not in replicated_device_to_used_device + ] + # Temporary mapping from replicated device to original device + temp_replicate_mapping = {} + + # Find a new device to replicate all partitions on an used device + for used_device in used_devices: + # Idle devices that have enough memory + available_devices = [ + d + for d in idle_devices + if d.available_mem_bytes + >= used_device.available_mem_bytes + - device_to_left_mem_bytes[used_device] + ] + if len(available_devices) == 0: + success = False + break + new_device = min(available_devices, key=lambda d: d.available_mem_bytes) + idle_devices.remove(new_device) + temp_replicate_mapping[new_device] = used_device + + if not success: + break + replicated_device_to_used_device.update(temp_replicate_mapping) + + # Update logical device IDs assigned to the partitions + for ( + replicate_device, + original_device, + ) in replicated_device_to_used_device.items(): + logical_id = replicate_device.logical_id + for partition in device_to_partitions[original_device]: + partition.logical_device_ids.append(logical_id) + for p in self.partitions: + print(p.logical_device_ids) + + def do_partition(self) -> GraphModule: + """Return a new fx module with submodule nodes (partitions).""" + module_with_submodules = split_module( + self.graph_module, + self.torch_module, + lambda node: self.node_to_partition[node], + ) + return module_with_submodules + + def dump_dag(self, module_with_submodules: GraphModule) -> DAG: + """Return the dag structure and the new fx module with submodules.""" + dag = DAG() + for node in module_with_submodules.graph.nodes: + if node.op == "output": + break + if node.op in {"placeholder", "get_attr"}: + continue + if node.target == operator.__getitem__: + continue + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # When a node has two or more output nodes, + # it outputs its result to 'getitem' nodes. + # Those 'getitem' nodes are the output node for this node. + # Otherwise, the output node is this node itself. + if len(node.users) > 1: + output_nodes = list(node.users) + else: + output_nodes = [node] + partition_id = int(node.name.rsplit("_", 1)[-1]) + device_ids = self.partitions[partition_id].logical_device_ids + size_bytes = self.partitions[partition_id].used_mem_bytes + dag.create_node( + node, list(input_nodes), output_nodes, device_ids, size_bytes + ) + return dag + + def create_partition(self) -> Partition: + """Create a partition and append it to self.partitions.""" + partition_id = len(self.partitions) + partition = Partition(partition_id) + self.partitions.append(partition) + return partition + + def create_single_node_partition(self, node): + """Create a partition for a single node""" + partition = self.create_partition() + partition.add_node(node) + return + + def sparse_nn_partition(self, available_mem_bytes: int) -> None: + """This method partition a sparse nn module. + It is size based partition but different from size_based_partition, + it only works when all the devices have same memory size (available_mem_bytes). + In the future, devices with different mem sizes will be supported like size_based_partition. + It first traverse all the nodes and do the partitions based on the same memory size. + If the current partition has no enough memory left for a new op node + (call_module, call_method, call_function), a new partition is created. + When crossing the boundary between non-embedding nodes and embedding nodes, + a new partition is created regardlessly. + For example, if the current node is a non-embedding node but the next node is an + embedding node, a new partition is created for the next node. + After the partition, the partitions are combined as much as possible. + The rule is that a non-embedding partition only + combines with another non-embedding one. + So as the embedding partitions. + """ + + def combine_partitions_based_on_size( + partitions: list[Partition], available_mem_bytes: int + ) -> None: + """Combining small partitions together to keep as less partitions as possible. + Here is an example of the algorithm to do this: + Assume some partitions, we first sort them based on partition used memory size. + [(partition_4, 1), (partition_3, 1), (partition_2, 2), (partition_1, 7), (partition_0, 9)] + The available memory is 10. + step 1: self.find_partition_to_combine_based_on_size() + First, mark bfs level for each partition + Second, look the smallest partition, partition_4: 10 - 1 = 9 + It means any partition has a used memory equal or less than 9 could combine this partition + We go from the largest and selection partition_0. + Check the bfs level for two partitions, if the level difference is less than 2, + it can be combined. + step 2: repeat step 1 until no partitions can be combined + """ + find_combination = True + while find_combination: + # Sort partitions based on memory size + sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes) + # Mark bfs level + get_bfs_level_partition(self.partitions) + find_combination, partitions = find_partition_to_combine_based_on_size( + sorted_partitions, available_mem_bytes, partitions + ) + return + + def calculate_mem_bytes_needed(p1, p2): + """Given two partitions, calculate how many mem bytes + are needed if two partitions are combined + """ + nodes = p1.nodes.union(p2.nodes) + mem_bytes_needed = 0 + for node in nodes: + mem_bytes_needed += get_extra_size_of(node, nodes) + return mem_bytes_needed + + def find_partition_to_combine_based_on_size( + sorted_partitions: list[Partition], + available_mem_bytes: int, + partitions: list[Partition], + ) -> tuple[bool, list[Partition]]: + """step 1 in combine_partition_based_on_size()""" + find_combination = False + smallest_partition = sorted_partitions.pop(0) + for p in sorted_partitions[::-1]: + if abs(smallest_partition.bfs_level - p.bfs_level) <= 1: + # Calculate how many bytes needed if combined + mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition) + if mem_bytes_needed <= available_mem_bytes: + combine_two_partitions(p, smallest_partition, self.partitions) + partitions.remove(smallest_partition) + partitions.remove(p) + partitions.append(self.partitions[-1]) + find_combination = True + break + return find_combination, partitions + + def reset_partition_in_sparse_nn(partition, new_partition=True): + """If crossing the boundary between non-embedding nodes and + embedding nodes, create a new partition + """ + if in_embedding_region: + embedding_partitions.append(partition) + else: + non_embedding_partitions.append(partition) + if new_partition: + partition = self.create_partition() + partition.left_mem_bytes = available_mem_bytes + return partition + return None + + def is_embedding_node(node: Node) -> bool: + """Check if a node is an embedding node""" + if node.op == "call_module": + submodule = self.graph_module + for atom in str(node.target).split("."): + if not hasattr(submodule, atom): + raise RuntimeError( + f"Module {submodule} has no attribute {atom}" + ) + submodule = getattr(submodule, atom) + if "Embedding" in str(submodule): + return True + return False + + # Track embedding partitions and non-embedding partitions separately + embedding_partitions: list[Partition] = [] + non_embedding_partitions: list[Partition] = [] + # A Flag to check the boundary + in_embedding_region: bool = False + partition = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op in {"call_module", "call_method", "call_function"}: + # Check if crossing the boundary between embedding nodes and non embedding nodes + if is_embedding_node(node) != in_embedding_region: + # Crossing the boundary + # Check if the current partition is an empty partition + if partition.used_mem_bytes != 0: + # The current partition isn't an empty partition. Create a new one. + partition = reset_partition_in_sparse_nn(partition) + in_embedding_region = not in_embedding_region + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if ( + total_size_of_input_nodes + partition.used_mem_bytes + > available_mem_bytes + ): + partition = reset_partition_in_sparse_nn(partition) + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if total_size_of_input_nodes > available_mem_bytes: + raise RuntimeError( + node.target + "is too large to fit into a device" + ) + partition.add_node(node) + reset_partition_in_sparse_nn(partition, new_partition=False) + # Set parents and children for partitions + set_parents_and_children(self.partitions) + # Combining non-embedding partitions + combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes) + # Combining embedding partitions + combine_partitions_based_on_size(embedding_partitions, available_mem_bytes) + total_size_of_non_embedding_partitions = 0 + for partition in non_embedding_partitions: + total_size_of_non_embedding_partitions += partition.used_mem_bytes + # Check if devices are enough for all partitions + if len(embedding_partitions) > len(self.devices): + msg = ( + "Need " + + str(len(embedding_partitions)) + + " devices, but only " + + str(len(self.devices)) + + " provided" + ) + raise RuntimeError(msg) + occupied_devices = [] + for i, partition in enumerate(embedding_partitions): + # Check if all non-embedding partitions can fit into embedding partition devices + if ( + total_size_of_non_embedding_partitions + partition.used_mem_bytes + > available_mem_bytes + ): + raise RuntimeError( + "partition_" + + str(partition.partition_id) + + "(embedding partition) and non embedding partitions can not fit into one device" + ) + else: + # Add logical device to the partition + partition.logical_device_ids = [self.devices[i].logical_id] + occupied_devices.append(self.devices[i].logical_id) + # Add logical devices to the non_embedding_partitions + for partition in non_embedding_partitions: + partition.logical_device_ids = occupied_devices + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def cost_aware_partition( + self, + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: dict[Node, NodeLatency], + ) -> None: + """This method is to partition the fx module based on the cost. + The cost is the total latency of running the whole fx module. + In partitioner_utils.py, the cost model is built. + The cost aware partition algorithm is: + #1. At every beginning, each node is a partition. + Then we map all the partitions to the devices + and calculate the cost + #2. Then try to pre-combine any two of the partitions if the two + partitions can be combined. + (the bfs level is less than 2 or two partitions are connected and + can find partition to device mapping) + See if any partition pair could reduce the current cost. + Choose the pair that shows the minimum cost and then combine them + #3. Repeat #2 until the cost cannot be reduced. + """ + + def try_combining_partitions(p0_index, p1_index, partitions) -> float: + """Given two partitions and a list of partitions, combine these two partitions + and see what is the cost of the modified partition list + """ + p0 = partitions[p0_index] + p1 = partitions[p1_index] + """If two partitions' bfs level are less than 2 or two partitions are connected to each other, + then they can be combined + """ + if ( + (abs(p0.bfs_level - p1.bfs_level) <= 1) + or (p0 in p1.parents) + or p0 in (p1.children) + ): + combine_two_partitions(p0, p1, partitions) + # Check if a circular dependency exists after combining + if check_dependency(partitions[-1]): + return float("inf") + # Check if the modified partition list can be mapped to devices after combination + reset_partition_device(partitions) + found_deivce = get_device_to_partitions_mapping( + partitions, self.devices + ) + if not found_deivce: + return float("inf") + # Calculate the new cost + partition_to_latency_mapping = get_partition_to_latency_mapping( + partitions, node_to_latency_mapping + ) + cost = get_latency_of_partitioned_graph( + partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + return cost + # If two partition can not be combined, the cost is inf + return float("inf") + + def search_combination( + transfer_rate_bytes_per_sec, node_to_latency_mapping + ) -> bool: + """Given transfer rate between partitions and each node's latency, + find two partitions to combine so the cost of the partitions can + be reduced. + The algorithm is : + 1. Go through all the partition pairs and see + if any pair of partitions can be combined. + 2. Calculate the cost after the combination. + 3. Select the minimum cost and combine its corresponding partition pair. + """ + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + cost = get_latency_of_partitioned_graph( + self.partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + if len(self.partitions) == 1: + return False + partition_pair: list[int] = [] + for i in range(len(self.partitions) - 1): + for j in range(i + 1, len(self.partitions)): + # Try to combine the partition pair + # and see the new cost after combination + new_cost = try_combining_partitions(i, j, self.partitions[:]) + if new_cost <= cost: + partition_pair = [i, j] + cost = new_cost + reorganize_partitions(self.partitions) + # If a partition pair is found, combine them + if len(partition_pair) != 0: + p0 = self.partitions[partition_pair[0]] + p1 = self.partitions[partition_pair[1]] + combine_two_partitions(p0, p1, self.partitions) + get_bfs_level_partition(self.partitions) + reset_partition_device(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + return len(partition_pair) != 0 + + for node in self.graph_module.graph.nodes: + if node.op not in {"placeholder", "get_attr", "output"}: + self.create_single_node_partition(node) + # Set up parent partitions and children partitions for each partition + set_parents_and_children(self.partitions) + # Get bfs level for each partition + get_bfs_level_partition(self.partitions) + find_combination = True + while find_combination: + # Search for a pair partition to generate the minimum new cost, + # then combine them + find_combination = search_combination( + transfer_rate_bytes_per_sec, node_to_latency_mapping + ) + # Make sure all partitions are set up correctly + reorganize_partitions(self.partitions) + # Set up node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def kl_based_partition( + self, + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: dict[Node, NodeLatency], + ) -> None: + """This function is a cost aware partition based + on Kernighan-Lin algorithm. + First, the graph is partitioned using size_based_partition. + Then, each node is swapped with any other node in a different + partition, and at the same time, the cost is estimated after + the swapping. + For example, we have nodes n0, n1, n2, n3 and n4. + Using size_based_partition, n0 and n1 are in Partition p0. + n2, n3 and n4 in Partition p1. The current cost is estimated. + We first tried using n0 to swap with n2 from the other partition. + Then we see that swapping n0 and n2 shows a lower cost + than the current cost and it is the minimum among other pairs like + (n0, None)(This means moving n0 to Partition without swapping other nodes), + (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost + as the current cost. + Then We repeat this process for all the other nodes until all swapping pairs + are tried. + """ + + def swap_nodes(n0, n1, p0, p1): + # Either n0 or n1 could be None + # That means we simply move the node + # to another partition + if n0 is not None: + p0.remove_node(n0) + p1.add_node(n0) + if n1 is not None: + p0.add_node(n1) + p1.remove_node(n1) + + def try_swap_nodes( + n0, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ): + cost = float("inf") + swap_nodes(n0, n1, p0, p1) + # Reorganize partitions after swapping + reorganize_partitions(self.partitions) + # Check if there is a circular dependency after swapping + if (not check_dependency(p0)) and (not check_dependency(p1)): + reset_partition_device(self.partitions) + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + # Check if all partitions can be mapped to logical devices after swapping + found_device = get_device_to_partitions_mapping( + self.partitions, self.devices + ) + if not found_device: + cost = float("inf") + else: + cost = get_latency_of_partitioned_graph( + self.partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + # Swap back and reset all partitions back to original + swap_nodes(n1, n0, p0, p1) + reorganize_partitions(self.partitions) + reset_partition_device(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + return cost + + def swap_node_to_partition( + node, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ): + """This function helps to swap one node from partition p0 + with all the nodes in another partition p1 + """ + p1_nodes = list(p1.nodes) + [None] + min_cost = float("inf") + node_pair: list[Node] = [] + for n1 in p1_nodes: + # Ignore the node if it is not a op node + if n1 is not None and n1.op in {"placeholder", "get_attr"}: + continue + # Try swapping node in p0 with n1 in p1 + cost = try_swap_nodes( + node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ) + if cost < min_cost: + node_pair = [node, n1] + min_cost = cost + return cost, node_pair # type: ignore[possibly-undefined] + + # First use size_base_partition + self.size_based_partition() + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + # Calculate the cost of the partitions + cost = get_latency_of_partitioned_graph( + self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec + ) + # Keep tracking the node pair that shows the better cost + node_pair: list[Node] = [] + # Keep tracking the partition pair of node pair + partition_pair: list[Partition] = [] + # Collect all the op nodes from the graph + op_nodes = [ + n + for n in self.graph_module.graph.nodes + if n.op not in {"placeholder", "get_attr", "output"} + ] + for node in op_nodes: + # Find which partition the current node belongs + p0_index = self.node_to_partition[node] + p0 = self.partitions[p0_index] + # Go through all the other partitions to swap + # with other nodes from those partitions + for p1_index, _ in enumerate(self.partitions): + if p0_index != p1_index: + p1 = self.partitions[p1_index] + new_cost, new_node_pair = swap_node_to_partition( + node, + p0, + p1, + node_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + # Update the cost + # Track the swapped node pair and their partitions + if new_cost < cost: + cost = new_cost + node_pair = new_node_pair + partition_pair = [p0, p1] + # Do the swapping after trying all the nodes from a partition + if len(node_pair) != 0: + swap_nodes( + node_pair[0], node_pair[1], partition_pair[0], partition_pair[1] + ) + reorganize_partitions(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + reorganize_partitions(self.partitions) + # Mapping the device to the partition + get_device_to_partitions_mapping(self.partitions, self.devices) + return + + def aot_based_partition( + self, node_to_partition_mapping, partition_to_logical_device_mapping + ): + """This function helps to rebuild the partitions given the nodes and its + corresponding partition id + """ + partition_id_to_partition_mapping: dict[int, Partition] = {} + self.node_to_partition = node_to_partition_mapping + for node in self.node_to_partition: + partition_id = self.node_to_partition[node] + # If the requested partition has not been created, create the partition + if partition_id not in partition_id_to_partition_mapping: + partition = Partition(partition_id) + self.partitions.append(partition) + partition_id_to_partition_mapping[partition_id] = partition + partition.logical_device_ids = partition_to_logical_device_mapping[ + partition_id + ] + else: + partition = partition_id_to_partition_mapping[ + self.node_to_partition[node] + ] + # Add the current node into the partition + partition.add_node(node) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/const_fold.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/const_fold.py new file mode 100644 index 0000000000000000000000000000000000000000..525014bf1e80e2c4a654a28b923e6216717872be --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/const_fold.py @@ -0,0 +1,311 @@ +# mypy: allow-untyped-defs +import re +from typing import Callable, Optional, Union + +import torch.fx +from torch.fx.node import map_arg +from torch.fx.passes.split_module import split_module + + +__all__ = [ + "FoldedGraphModule", + "get_unique_attr_name_in_module", + "split_const_subgraphs", +] + + +class FoldedGraphModule(torch.fx.GraphModule): + """ + FoldedGraphModule is a GraphModule which also contains another + `const_subgraph_module` representing a subgraph which has all const attr + inputs and which can be run once before running the main standard + `graph`. The `const_output_names` are the ordered list names of attrs which + represent what each respective output from the const_subgraph should be set + on which attrs. + """ + + def __init__( + self, + root: torch.nn.Module, + graph: torch.fx.Graph, + const_subgraph: Optional[torch.fx.Graph] = None, + fx_const_folded_attrs_name: Optional[str] = None, + device_for_folded_attrs: str = "cuda", + ): + super().__init__(root, graph) + self.const_subgraph_module = ( + None + if const_subgraph is None + else torch.fx.GraphModule(root, const_subgraph) + ) + self.has_folding_been_run = False + self.fx_const_folded_attrs_name = fx_const_folded_attrs_name + self.device_for_folded_attrs = device_for_folded_attrs + + def __call__(self, *args, **kwargs): + if not self.has_folding_been_run: + self.run_folding() + return super().__call__(*args) + + def run_folding(self): + # If there's no const subgraph module or attr output names to use, return + # early as there is no const folding to perform. + if ( + self.const_subgraph_module is None + or self.fx_const_folded_attrs_name is None + ): + return + + assert not self.has_folding_been_run + self.has_folding_been_run = True + + # Actually run const folding subgraph. Note that single attr const fold + # subgraphs output a single Tensor while multiple outputs are returned as + # Tuple[Tensor,]. + folded_attrs = self.const_subgraph_module() + + def _create_param(i): + return torch.nn.Parameter( + i.detach().clone() + if not isinstance(i, int) + else torch.Tensor([i]).to(device=self.device_for_folded_attrs), + requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False, + ) + + params = ( + torch.nn.ParameterList([_create_param(i) for i in folded_attrs]) + if isinstance(folded_attrs, tuple) + else _create_param(folded_attrs) + ) + setattr(self, self.fx_const_folded_attrs_name, params) + + +def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str): + """ + Given `gm` and some graph module which is called with target name `inline_mod_name`, + this helper will inline all of the nodes from that called graph module into `gm`. + """ + # Fetch the inner graph module that we want to inline inside `gm`. + inline_mod = dict(gm.named_modules())[inline_mod_name] + assert isinstance(inline_mod, torch.fx.GraphModule) + call_mod_node_to_replace = None + for node in gm.graph.nodes: + if node.op == "call_module" and node.target == inline_mod_name: + call_mod_node_to_replace = node + break + assert call_mod_node_to_replace is not None + + # Now actually do the swap. Note that we have to keep track of new nodes that are + # copied into `gm` -- we do this via replacement_mapping. + call_mod_args = call_mod_node_to_replace.args + call_mod_kwargs = call_mod_node_to_replace.kwargs + + replacement_mapping: dict[torch.fx.Node, torch.fx.Node] = {} + ph_count = 0 + + def replacement_fn(node): + new_node = replacement_mapping[node] + new_node.meta = node.meta.copy() + return new_node + + for inline_node in inline_mod.graph.nodes: + if inline_node.op == "placeholder": + replacement_mapping[inline_node] = ( + call_mod_kwargs[inline_node.name] + if inline_node.name in call_mod_kwargs + else call_mod_args[ph_count] + ) + + ph_count += 1 + continue + + if inline_node.op == "output": + outputs = inline_node.args[0] + output_replacements = map_arg(outputs, replacement_fn) + call_mod_node_to_replace.replace_all_uses_with(output_replacements) + continue + + with gm.graph.inserting_before(call_mod_node_to_replace): + new_node = gm.graph.node_copy(inline_node, replacement_fn) + replacement_mapping[inline_node] = new_node + + gm.graph.eliminate_dead_code() + + +def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str: + """ + Make sure the name is unique (in a module) and can represents an attr. + """ + # Delete all characters that are illegal in a Python identifier. + name = re.sub("[^0-9a-zA-Z_]+", "_", name) + if name[0].isdigit(): + name = f"_{name}" + # Now make sure it is in fact unique to the module by incrementing suffix value. + while hasattr(mod_traced, name): + match = re.match(r"(.*)_(\d+)$", name) + if match is None: + name = name + "_1" + else: + base, num = match.group(1, 2) + name = f"{base}_{int(num) + 1}" + + return name + + +def split_const_subgraphs( + module: Union[torch.nn.Module, torch.fx.GraphModule], + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, + device_for_folded_attrs: str = "cpu", +) -> FoldedGraphModule: + """ + Looks through `module` for any nodes that have all constant attribute inputs + and separates them out into their own constant subgraph, and returns a + FoldedGraphModule which runs that constant subgraph on the first run to set + attributes on the module prior to running the non-constant portion of the + graph. + """ + if not isinstance(module, torch.fx.GraphModule): + mod_traced = torch.fx.symbolic_trace(module) + else: + mod_traced = module + + # Build up a list of const_nodes, defined as nodes that are themselves + # get_attrs, or have all get_attr or other constant node inputs. + const_nodes: set[torch.fx.Node] = set() + found_const_folding = False + for node in mod_traced.graph.nodes: + # Skip over placeholders/outputs because they can't be const folded and + # we don't want to add tags to them. + if node.op in {"placeholder", "output"}: + continue + + # If the node itself is constant, or all of its inputs are constant, + # then tag it as constant. + if node.op != "get_attr" and not set(node.all_input_nodes).issubset( + const_nodes + ): + continue + + # If provided skip folding function says to skip, then skip. + if skip_folding_node_fn and skip_folding_node_fn(node): + continue + + # Skip folding side-effectful functions + if node.is_impure(): + continue + + # Must be a constant foldable node at this point. + const_nodes.add(node) + if node.op != "get_attr": + found_const_folding = True + + # If we did not find any const folding then return early without a const fold subgraph. + if not found_const_folding: + return FoldedGraphModule(mod_traced, mod_traced.graph) + + # Partition the module into two: submod_0 for constant folding subgraph, and + # submod_1 for the rest. + def mod_partition(node: torch.fx.Node): + return 0 if node in const_nodes else 1 + + split = split_module(mod_traced, module, mod_partition) + + const_mod_name, non_const_mod_name = "submod_0", "submod_1" + # Safely get submod_1 in case there are no non-const nodes + const_gm, non_const_gm = split.submod_0, getattr(split, non_const_mod_name, None) + + # The module that a call_module node refers to gets copied to submodules during split. + # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to + # attach inlined modules to `split` as it's the owning module now. + for node in non_const_gm.graph.nodes if non_const_gm else []: + if node.op == "call_module": + setattr(split, node.target, getattr(non_const_gm, node.target)) + for node in const_gm.graph.nodes: + if node.op == "call_module": + setattr(split, node.target, getattr(const_gm, node.target)) + + # split_module currently does not use get_attrs for attrs. Instead it passes + # them in as args from the parent module, which used get_attrs. Here we set + # them as get_attrs inside const_gm, allowing for running folding without + # somehow a priori knowing the attrs that should be passed as args. We can + # unconditionally do this for all placeholders because we know all + # placeholders to const_gm must be constants accessible via get_attr. + call_const_gm_args = None + for node in split.graph.nodes: + if node.op == "call_module": + if node.target == const_mod_name: + call_const_gm_args = node.args + break + assert call_const_gm_args is not None + + # Here we do the actual replacement of placeholders to get_attrs. Note that here we + # set the const_gm.graph into a new root_const_gm with split as the root module, + # because we are fetching attributes directly from the root module, instead of + # fetching them from const_gm. Example: The const_gm must have some format like: + # graph(): + # %inp : [num_users=1] = placeholder[target=const_inp] + # %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {}) + # return add + # We replace that with the following, which does not have any placeholders: + # graph(): + # %inp_1 : [num_users=1] = get_attr[target=const_inp] + # %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {}) + # return add + root_const_gm = torch.fx.GraphModule(split, const_gm.graph) + + # The order of placeholders in the const_gm graph should match the order of + # args in the outer module, so we can simply use an index for the + # placeholder mapping + ph_idx = 0 + for node in root_const_gm.graph.nodes: + if node.op == "output": + multiple_outputs = isinstance(node.args[0], tuple) + continue + if node.op != "placeholder": + continue + assert ph_idx < len(call_const_gm_args) + in_node = call_const_gm_args[ph_idx] + ph_idx += 1 + assert in_node.op == "get_attr" + with root_const_gm.graph.inserting_before(node): + new_node = root_const_gm.graph.get_attr(in_node.target) + new_node.meta = node.meta.copy() + node.replace_all_uses_with(new_node) + root_const_gm.graph.erase_node(node) + assert "multiple_outputs" in locals() + + # Now find the call to const_gm inside split, and replace it with a getattr to the + # folded tensor(s) that result from constant folding. Note that we don't need to + # worry about whether this is one or more tensors because the original graph + # correctly uses getitem to extract individual tensors if there are multiple folded. + fx_const_folded_attrs_name = get_unique_attr_name_in_module( + mod_traced, "_FX_CONST_FOLDED_ATTRS" + ) + setattr( + split, + fx_const_folded_attrs_name, + torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(), # type: ignore[possibly-undefined] + ) + for node in split.graph.nodes: + if node.op == "call_module" and node.target == const_mod_name: + with node.graph.inserting_before(node): + folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name) + folded_attrs.meta = node.meta.copy() + node.replace_all_uses_with(folded_attrs) + break + + # Finally, inline the non-constant submod (if it exists) into the split submod. + # This is so that the original caller who may have passed in a graph module will + # get back out a graph module whose graph is traced to the same granularity. + if hasattr(split, non_const_mod_name): + _inline_module(split, non_const_mod_name) + + split.graph.eliminate_dead_code() + + return FoldedGraphModule( + split, + split.graph, + root_const_gm.graph, + fx_const_folded_attrs_name, + device_for_folded_attrs, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/debug.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..b87dee9db9c73f0b4ea1a0a27682a167e125a71d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/debug.py @@ -0,0 +1,33 @@ +from collections.abc import Sequence + +import torch.fx as fx + + +__all__ = ["set_trace"] + + +def set_trace(gm: fx.GraphModule) -> fx.GraphModule: + """ + Sets a breakpoint in `gm`'s generated python code. It drops into pdb when + `gm` gets run. + + Args: + gm: graph module to insert breakpoint. It is then recompiled for it to + take effect. + + Returns: + the `gm` with breakpoint inserted. + """ + + def insert_pdb(body: Sequence[str]) -> list[str]: + return ["import pdb; pdb.set_trace()\n", *body] + + with gm.graph.on_generate_code( + make_transformer=lambda cur_transform: ( + # new code transformer to register + lambda body: (insert_pdb(cur_transform(body) if cur_transform else body)) + ) + ): + gm.recompile() + + return gm diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/graph_gradual_typechecker.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/graph_gradual_typechecker.py new file mode 100644 index 0000000000000000000000000000000000000000..3b15ae0a6739cf91b0bbe3cd8ef7da0cefe091d5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/graph_gradual_typechecker.py @@ -0,0 +1,1024 @@ +# mypy: allow-untyped-defs +import itertools +import operator +from functools import reduce +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +import sympy + +import torch +from torch.fx.experimental.refinement_types import Equality +from torch.fx.experimental.unification import Var # type: ignore[attr-defined] +from torch.fx.node import Node, Target +from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.conv import Conv2d + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +_INFERENCE_RULES: dict[Target, Callable] = {} +_REFINEMENT_RULES: dict[Target, Callable] = {} +_RULES: dict[Target, Callable] = {} + +__all__ = [ + "GraphTypeChecker", + "Refine", + "adaptiveavgpool2d_check", + "adaptiveavgpool2d_inference_rule", + "add_inference_rule", + "all_eq", + "bn2d_inference_rule", + "broadcast_types", + "calculate_out_dimension", + "conv2d_inference_rule", + "conv_refinement_rule", + "conv_rule", + "element_wise_eq", + "expand_to_tensor_dim", + "first_two_eq", + "flatten_check", + "flatten_inference_rule", + "flatten_refinement_rule", + "get_attr_inference_rule", + "get_greatest_upper_bound", + "get_parameter", + "linear_check", + "linear_inference_rule", + "linear_refinement_rule", + "maxpool2d_check", + "maxpool2d_inference_rule", + "register_algebraic_expressions_inference_rule", + "register_inference_rule", + "register_refinement_rule", + "relu_inference_rule", + "reshape_inference_rule", + "transpose_inference_rule", +] + + +def expand_to_tensor_dim(t, n): + """ + Expand a type to the desired tensor dimension if possible + Raise an error otherwise. + - t is the given type + - n is a number of dimensions to expand to + """ + if t == Dyn: + dims = [Dyn] * n + return TensorType(tuple(dims)) + elif isinstance(t, TensorType): + if len(t.__args__) != n: + raise TypeError( + f"Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}" + ) + return t + else: + raise TypeError(f"Cannot match the type {t}") + + +def broadcast_types(t1, t2): + """ + Applies broadcasting to both given types such that they + become consistent with eachother and returns two new + resulting types + """ + + # if either type is Dyn, do nothing since the types are already consistent + if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): + return t1, t2 + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + s1 = len(t1.__args__) + s2 = len(t2.__args__) + + new_t1 = list(t1.__args__) + new_t2 = list(t2.__args__) + + # We make the types the same length which is the first requirement + # for consistency + if s1 > s2: + for i in range(s1 - s2): + new_t2.insert(0, 1) + + elif s2 > s1: + for i in range(s2 - s1): + new_t1.insert(0, 1) + + # we replace occurrences of "1" with each tensor with + # the corresponding type from the other tensor + for i, (x, y) in enumerate(zip(new_t1, new_t2)): + if x == 1: + new_t1[i] = y + elif y == 1: + new_t2[i] = x + + # at this point our tensors should be consistent + # and we can apply the element-wise operation and find the right dimension + # for the output of the operation + (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2)) + return (t1, t2) + else: + raise TypeError(f"Cannot broadcast types {t1} and {t2}") + + +def register_inference_rule( + call_target: Target, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if call_target in _INFERENCE_RULES: + raise RuntimeError(f"Inference rule already registered for {call_target}!") + _INFERENCE_RULES[call_target] = fn + return fn + + return register + + +def register_refinement_rule( + call_target: Target, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if call_target in _REFINEMENT_RULES: + raise RuntimeError(f"Refinement rule already registered for {call_target}!") + _REFINEMENT_RULES[call_target] = fn + return fn + + return register + + +def register_algebraic_expressions_inference_rule( + call_target: Target, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if call_target in _RULES: + raise RuntimeError(f"Rule already registered for {call_target}!") + _RULES[call_target] = fn + return fn + + return register + + +@register_inference_rule(torch.add) +@register_inference_rule(operator.add) +def add_inference_rule(n: Node): + """ + Apply the addition inference rule. This includes: + - scalar addition + - broadcasting semantics + + Note that we always return the least precise type between + the operands (after applying broadcasting) to be the final type of the operation + + Note that we do not modify the operand types themselves after applying broadcasting + to them. We only use them to calculate the final type + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + t1 = n.args[0].type + t2 = n.args[1].type + + # handle scalar addition + if t1 == int and isinstance(t2, TensorType): + n.type = t2 + return n.type + + # handle scalar addition + elif t2 == int and isinstance(t1, TensorType): + n.type = t1 + return n.type + + # we bring the new types to the point where + # we can check for consistency + # any inconsistency would not have been caused + # by broadcasting at this point + (new_t1, new_t2) = broadcast_types(t1, t2) + + if new_t1 != t1 or new_t2 != t2: + n.meta["broadcast"] = True + n.meta[str(n.args[0])] = new_t1 + n.meta[str(n.args[1])] = new_t2 + + else: + n.meta["broadcast"] = False + + new_t1 = t1 if not n.meta["broadcast"] else new_t1 + new_t2 = t2 if not n.meta["broadcast"] else new_t2 + + # we check for consistency between the new types + if is_consistent(new_t1, new_t2): + # we return the less precise type because + # broadcasting may have happened + # for operands with shape [1,2,Dyn] and [1,2,1] + # we have to assign the node [1,2,Dyn] + if is_more_precise(new_t1, new_t2): + n.type = new_t2 + else: + n.type = new_t1 + return n.type + else: + raise TypeError( + f"Cannot add arguments {n.args[0]} ({n.args[0].type}) and {n.args[1]} ({n.args[1].type}) in node {n}." + f" Types should match " + ) + + +@register_inference_rule(getattr) +def get_attr_inference_rule(n: Node, traced): + """ + The current getattr rule only handles the shape attribute + Can be extended to other attributes + The most representitive type we have is "Dyn" but the system + can be extended with more types, such as a type to represent shapes + """ + attr_name = n.args[1] + + if attr_name == "shape": + n.type = Dyn + else: + raise TypeError("Not yet implemented") + + # TODO. We leave it like this till we add a type to represent tensor sizes + return n.type + + +@register_inference_rule(torch.transpose) +def transpose_inference_rule(n: Node): + """ + We check that dimensions for the transpose operations + are within range of the tensor type of the node + """ + if n.target == torch.transpose: + assert isinstance(n.args[0], Node) + t = n.args[0].type + + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], int) + dim1, dim2 = n.args[1], n.args[2] + + if t == Dyn: + n.type = Dyn + return n.type + + elif isinstance(t, TensorType): + if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__): + new_type = list(t.__args__) + new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1] + final = TensorType(new_type) + n.type = get_greatest_upper_bound(n.type, final) + return n.type + else: + raise TypeError( + f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}" + ) + else: + raise TypeError( + f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}" + ) + + +@register_inference_rule(torch.reshape) +def reshape_inference_rule(n: Node): + """ + Without dynamism, the rule checks that the + product of the elements of the argument tensor + type is equal to the product of the elements + of the required shape. We gradualize this rule + by adding a case to handle fully dynamic input + as well as input where some of the tensor dimensions + are unknown. In this case we check for divisibility + """ + assert isinstance(n.args[0], Node) + t1 = n.args[0].type + + assert isinstance(n.args[1], list) + t2 = n.args[1] + t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) + + # if we do not know the original tensor dimension, + # we return the required dimension + if t1 == Dyn: + n.type = t2_type + return t2_type + + # if any of the dimensions are unknown, + # we check for divisibility + elif isinstance(t1, TensorType): + assert isinstance(t1, TensorType) + a = [e if e != Dyn else 1 for e in t1.__args__] + p1 = reduce(operator.mul, a) + p2 = reduce(operator.mul, t2) + if p1 % p2 == 0 or p2 % p1 == 0: + n.type = t2_type + return t2_type + else: + raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}") + else: + raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}") + + +@register_inference_rule(BatchNorm2d) +def bn2d_inference_rule(n: Node, module_instance): + """ + Given a BatchNorm2D instance and a node check the following conditions: + - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, x_3, x_4) + - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') + - t is consistent with t' + - x_2 is consistent with the module's num_features + - x_2' is consistent with the module's num_features + output type: the more precise type of t and t' + """ + assert isinstance(n.args[0], Node) + n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) + arg_type = n.args[0].type + n.type = expand_to_tensor_dim(n.type, 4) + + # we check the conditions on the incoming argument + # and any existing annotation + # we also check for consistency between both annotations + if ( + is_consistent(arg_type.__args__[1], module_instance.num_features) + and is_consistent(n.type.__args__[1], module_instance.num_features) + and is_consistent(arg_type, n.type) + ): + # we choose the more precise type + # to be the node type + # so if an incoming argument has more type information + # we set this node's type to be the argument type + n.type = get_greatest_upper_bound(arg_type, n.type) + return n.type + else: + raise TypeError( + f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}" + ) + + +def calculate_out_dimension(d_in, module_instance, index): + """ + For calculating h_in and w_out according to the conv2D documentation + """ + padding = ( + (module_instance.padding, module_instance.padding) + if isinstance(module_instance.padding, int) + else module_instance.padding + ) + kernel_size = ( + (module_instance.kernel_size, module_instance.kernel_size) + if isinstance(module_instance.kernel_size, int) + else module_instance.kernel_size + ) + stride = ( + (module_instance.stride, module_instance.stride) + if isinstance(module_instance.stride, int) + else module_instance.stride + ) + dilation = ( + (module_instance.dilation, module_instance.dilation) + if isinstance(module_instance.dilation, int) + else module_instance.dilation + ) + + DIMENSION_TYPES = (int, sympy.Symbol) + + if d_in == Dyn: + return Dyn + + elif isinstance(d_in, DIMENSION_TYPES): + n = d_in + 2 * padding[index] - dilation[index] * (kernel_size[index] - 1) - 1 + + return (n // stride[0]) + 1 + + else: + raise TypeError( + f"{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}" + ) + + +def get_greatest_upper_bound(type1, type2): + """ + Get the most precise type that's consistent with the given types + """ + if type1 == Dyn: + return type2 + elif type2 == Dyn: + return type1 + elif isinstance(type1, TensorType) and isinstance(type2, TensorType): + if not is_consistent(type1, type2): + raise TypeError(f"Inconsistent types {type1}, {type2}") + gub = [ + t1 if is_more_precise(t1, t2) else t2 + for (t1, t2) in zip(type1.__args__, type2.__args__) + ] + return TensorType(tuple(gub)) + + +@register_inference_rule(Conv2d) +def conv2d_inference_rule(n: Node, module_instance): + """ + Given a Conv2D instance and a node check the following conditions: + - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, H, W) + - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') + - x_2 is consistent with the module's in_channels + - let o = (x_1, out_channels, H_out, W_out) + then the output is the greatest upper bound of o and the existing node type t'. + """ + assert isinstance(n.args[0], Node) + n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) + arg_type = n.args[0].type + curr_node_type = expand_to_tensor_dim(n.type, 4) + + if is_consistent(arg_type.__args__[1], module_instance.in_channels): + w_in = arg_type.__args__[3] + h_in = arg_type.__args__[2] + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + new_type = TensorType( + (arg_type.__args__[0], module_instance.out_channels, h_out, w_out) + ) + gub = get_greatest_upper_bound(new_type, curr_node_type) + n.type = gub + return n.type + else: + raise TypeError( + f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}" + ) + + +@register_inference_rule(torch.nn.ReLU) +def relu_inference_rule(n: Node, module_instance): + """ + Input and output shapes should be equal. + """ + assert isinstance(n.args[0], Node) + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + + if isinstance(n.args[0].type, TensorType): + n.type = get_greatest_upper_bound(n.args[0].type, n.type) + return n.type + + +def maxpool2d_check(typ, module_instance): + """ + Applies the maxpool2d shape information to the input + this affects the last two dimensions + """ + new_type_list = list(typ.__args__) + if len(new_type_list) == 4 or len(new_type_list) == 3: + w_in = new_type_list[-1] + h_in = new_type_list[-2] + + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + + new_type_list[-1] = w_out + new_type_list[-2] = h_out + return TensorType(tuple(new_type_list)) + + else: + raise TypeError(f"Wrong size {typ} for {module_instance}") + + +@register_inference_rule(torch.nn.MaxPool2d) +def maxpool2d_inference_rule(n: Node, module_instance): + """ + Given a MaxPool2D instance and a node check the following conditions: + - Input size matches size 3 or 4 + - Current node type is consistent with the output type we will calculate + - Input size matches output size and the last two dimensions of the output + are w_out and h_out. The remaining dimensions are the same as the input + - Our final result is the greatest upper bound of the output we calculate + and the current node type. + """ + assert isinstance(n.args[0], Node) + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output = maxpool2d_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(output, n.type) + return n.type + + +def linear_check(tensor_type, module_instance): + """ + Checks that an input tensor type satisfies the conditions for linear operation + and returns the output type based on in and out features given by module_instance + """ + if len(tensor_type.__args__) >= 2: + if is_consistent(module_instance.in_features, tensor_type.__args__[-1]): + new_type_args = list(tensor_type.__args__) + new_type_args[-1] = module_instance.out_features + return TensorType(tuple(new_type_args)) + else: + raise TypeError( + f"Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}" + ) + else: + raise TypeError(f"Type {tensor_type} must have rank 2 or more.") + + +@register_inference_rule(torch.nn.Linear) +def linear_inference_rule(n: Node, module_instance): + """ + Applies the shape information to the input then gets the greatest upper bound + of the resulting type and the existing type + """ + assert isinstance(n.args[0], Node) + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output_type = linear_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(output_type, n.type) + return n.type + + +def adaptiveavgpool2d_check(tensor_type, module_instance): + output_size = module_instance.output_size + if isinstance(output_size, int): + output_size = [output_size, output_size] + elif isinstance(output_size, tuple): + output_size = list(output_size) + if output_size[0] is None: + output_size[0] = output_size[1] + if output_size[1] is None: + output_size[1] = output_size[0] + + new_type_list = list(tensor_type.__args__) + + if len(tensor_type.__args__) == 4 or len(tensor_type.__args__) == 3: + new_type_list[-1] = output_size[1] + new_type_list[-2] = output_size[0] + + return TensorType(tuple(new_type_list)) + + else: + raise TypeError(f"Tensor ranks must be 3 or 4. Got {tensor_type}") + + +@register_inference_rule(torch.nn.AdaptiveAvgPool2d) +def adaptiveavgpool2d_inference_rule(n: Node, module_instance): + """ + The input and output sizes should be the same except for the last + two dimensions taken from the input, which represent width and height + """ + assert isinstance(n.args[0], Node) + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output_type = adaptiveavgpool2d_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(n.type, output_type) + return n.type + + +def flatten_check(tensor_type, start_dim, end_dim): + l = len(tensor_type.__args__) + + start_dim = l if start_dim == -1 else abs(start_dim) + end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 + + if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim: + my_args = list(tensor_type.__args__) + lhs = my_args[0:start_dim] + rhs = my_args[end_dim:] + mid = my_args[start_dim:end_dim] + if Dyn in mid: + mid = [Dyn] + else: + mid = [reduce(operator.mul, my_args[start_dim:end_dim])] + new_type_list = lhs + mid + rhs + return TensorType(tuple(new_type_list)) + else: + raise TypeError( + f"Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}" + ) + + +@register_inference_rule(torch.flatten) +def flatten_inference_rule(n: Node): + """ + Applies the flatten shape information to the input then gets the + greatest upper bound of the resulting type and the existing type + """ + assert isinstance(n.args[0], Node) + + # set the default start and end dims + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + + if isinstance(n.args[0].type, TensorType): + output_type = flatten_check(n.args[0].type, start_dim, end_dim) + n.type = get_greatest_upper_bound(output_type, n.type) + + return n.type + + +class GraphTypeChecker: + def __init__(self, env, traced): + self.env = env + self.traced = traced + + def type_check(self): + """ + A gradual type checker for graphs + Effect: every node's field type will be + populated with a type after type-checking is done + """ + graph = self.traced.graph + + # type check every node with gradual type rules + # if any node does not type check return false + for n in graph.nodes: + self.type_check_node(n) + return True + + def type_check_node(self, n: Node): + """ + Type check a given fx node. + Current operations: + - Reshape + - Transpose + - Add + - Relu + - conv2d + - batchnorm2d + - flatten + - maxpool2d + - adaptiveavgpool2d + - linear + """ + if n.type is None: + n.type = Dyn + + if n.op == "placeholder": + return n.type + + elif n.op == "get_attr": + t = get_parameter(self.traced, n.target) # type: ignore[arg-type] + if isinstance(t.data, torch.Tensor): + n.type = TensorType(t.data.shape) + return n.type + + elif n.op == "call_function": + if n.target == getattr: + assert getattr in _INFERENCE_RULES + return _INFERENCE_RULES[n.target](n, self.traced) + + elif n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target](n) + else: + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) + + elif n.op == "call_module": + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _INFERENCE_RULES: + return _INFERENCE_RULES[type(module_instance)](n, module_instance) + else: + raise RuntimeError( + f"No inference rule registered for class {type(module_instance)}!" + ) + + elif n.op == "output": + + def get_node_type(a): + return a.type + + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + raise NotImplementedError(f"Method {n.op} not yet implemented") + + +@register_refinement_rule(Conv2d) +def conv_refinement_rule(n: Node): + """ + The equality constraints are between the first dimension of + the input and output + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + res = [Equality(arg_type.__args__[0], n.type.__args__[0])] + return res + + +@register_refinement_rule(torch.nn.Linear) +def linear_refinement_rule(n: Node): + """ + The equality constraints are between the first dimension of + the input and output + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + res = [Equality(arg_type.__args__[0], n.type.__args__[0])] + return res + + +@register_refinement_rule(BatchNorm2d) +@register_refinement_rule(torch.nn.ReLU) +def all_eq(n: Node): + """ + For operations where the input shape is equal to the output shape + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + args1 = arg_type.__args__ + args2 = n.type.__args__ + res = [Equality(args1[i], args2[i]) for i in range(len(args1))] + return res + + +@register_refinement_rule(torch.nn.AdaptiveAvgPool2d) +@register_refinement_rule(torch.nn.MaxPool2d) +def first_two_eq(n: Node): + """ + For operations where the first two dimensions of the input and output shape + are equal + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + args1 = arg_type.__args__ + args2 = n.type.__args__ + res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])] + return res + + +@register_refinement_rule(torch.add) +@register_refinement_rule(operator.add) +def element_wise_eq(n: Node): + """ + For element-wise operations and handles broadcasting. + Note that after applying broadcasting to the arguments + we are able to determine if certain dimensions have not been broadcast + if they are symbolicallu equal. + + in this case, we can establish equality between those dimensions and the + corresponding output dimensions. + + Note that it takes two iterations for this result. One iteration to establish + equality between certain dimensions of the operands (requiring the whole solver + including unification) and another iteration to establish equality between the operands + and the resulting type, requiring another round of constraint generation and unificaiton. + """ + res = [] + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + arg_type1 = n.args[0].type + arg_type2 = n.args[1].type + if ( + isinstance(arg_type1, TensorType) + and isinstance(arg_type2, TensorType) + and isinstance(n.type, TensorType) + ): + args1, args2 = broadcast_types(arg_type1, arg_type2) + # by this point, we know that args1 and args2 are the same size. + a1 = args1.__args__ + a2 = args2.__args__ + a3 = n.type.__args__ + + # we would be here in the second iteration where we establish equality + # between operand type dimensions and the resulting type dimensions + r = [] + for x, y, z in zip(a1, a2, a3): + if x == y: + r.append(Equality(x, z)) + res = r + return res + + +@register_refinement_rule(torch.flatten) +def flatten_refinement_rule(n: Node): + """ + Generates equality constraints between the dimensions of the input and output + that will not be involved in the flatten operation + """ + assert isinstance(n.args[0], Node) + + eq_const = [] + + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + if isinstance(n.type, TensorType) and isinstance(n.args[0].type, TensorType): + l = len(n.type.__args__) + arg_type = n.args[0].type + start_dim = l if start_dim == -1 else start_dim + end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 + + for t1, t2 in zip(n.type.__args__[0:start_dim], arg_type.__args__[0:start_dim]): + eq_const.append(Equality(t1, t2)) + + for t1, t2 in zip(n.type.__args__[end_dim:], arg_type.__args__[end_dim:]): + eq_const.append(Equality(t1, t2)) + return eq_const + + +@register_algebraic_expressions_inference_rule(Conv2d) +def conv_rule(n: Node, module_instance): + """ + Represents the outout in terms of an algrbraic expression w.r.t + the input when possible + """ + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + w_in = arg_type.__args__[3] + h_in = arg_type.__args__[2] + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out)) + n.type = new_type + return new_type + + +class Refine: + """ + Symbolic shape inference. + Generates constraints over type variables. + Currently all constraints are equality constraints. + """ + + def __init__(self, traced): + self.constraints = [] + self.traced = traced + self.symbol_iter = itertools.count(start=0, step=1) + + def refine(self): + """ + Generates constraints for + every node in the graph based on + the operation. + """ + graph = self.traced.graph + for n in graph.nodes: + self.refine_node(n) + return True + + def symbolic_relations(self): + """ + Infers algebraic relations + """ + graph = self.traced.graph + for n in graph.nodes: + self.infer_symbolic_relations(n) + return True + + def replace_dyn_with_fresh_var(self, typ): + """ + Replace all unknown types with fresh type variables. + """ + if typ == Dyn: + new_symbol = Var(next(self.symbol_iter)) + return new_symbol + elif isinstance(typ, TensorType): + new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__] + return TensorType(tuple(new_args)) + elif isinstance(typ, list): + return [self.replace_dyn_with_fresh_var(t) for t in typ] + elif isinstance(typ, tuple): + return (self.replace_dyn_with_fresh_var(t) for t in typ) + else: + return typ + + def convert_to_sympy_symbols(self, typ): + """ + Replace all unknown types with fresh type variables. + """ + if isinstance(typ, Var): + return sympy.symbols(str(typ)) + elif isinstance(typ, TensorType): + new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__] + return TensorType(tuple(new_args)) + elif isinstance(typ, list): + return [self.convert_to_sympy_symbols(t) for t in typ] + elif isinstance(typ, tuple): + return (self.convert_to_sympy_symbols(t) for t in typ) + else: + return typ + + def refine_node(self, n: Node): + """ + Returns a list of equality constraints for + call_module and call_function nodes. + Models the relation between input and output dimensions + using constraints in case they are both tensors. + All operations used in resnet50 are defined. + """ + if n.type is None: + n.type = Dyn + + n.type = self.replace_dyn_with_fresh_var(n.type) + + if n.op == "call_function": + if n.target in _REFINEMENT_RULES: + self.constraints += _REFINEMENT_RULES[n.target](n) + else: + pass + + if n.op == "call_module": + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _REFINEMENT_RULES: + self.constraints += _REFINEMENT_RULES[type(module_instance)](n) + else: + pass + + if n.op == "output": + + def get_node_type(a): + return a.type + + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + pass + + def infer_symbolic_relations(self, n: Node): + n.type = self.convert_to_sympy_symbols(n.type) + if n.op == "call_function": + if n.target in _RULES: + return _RULES[n.target](n) + else: + pass + + if n.op == "call_module": + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _RULES: + return _RULES[type(module_instance)](n, module_instance) + else: + pass + + if n.op == "output": + + def get_node_type(a): + return a.type + + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + pass + + +def get_parameter(traced, target: str): + """ + Returns the parameter given by ``target`` if it exists, + otherwise throws an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the Parameter + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Parameter: The Parameter referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Parameter`` + """ + module_path, _, param_name = target.rpartition(".") + + mod: torch.nn.Module = traced.get_submodule(module_path) + + if not hasattr(mod, param_name): + raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`") + + param: torch.nn.Parameter = getattr(mod, param_name) + + return param diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/merge_matmul.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/merge_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a51918f930ffdcdff63dc33ac1aac0a4e98bd0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/merge_matmul.py @@ -0,0 +1,177 @@ +# mypy: allow-untyped-defs +import itertools +import operator + +import torch +from torch.fx._symbolic_trace import symbolic_trace +from torch.fx.node import Node +from torch.fx.passes.tools_common import legalize_graph + + +def split_result_tensors( + result: torch.Tensor, inputs: list[torch.Tensor] +) -> tuple[torch.Tensor, ...]: + """ + A free function for use in the merge_matmul graph transformation below that + splits the output from a merged matmul into the individual results for each + input tensor. + + Arguments: + result: The merged matmul result tensor. + inputs: The list of inputs that were merged into one for the matmul. + + Returns: + List of matmul results for each input tensor. + """ + # When fx tracer is running, x.shape[0] will be torch.fx.Attribute but we + # need an int even when tracing + if isinstance(result, torch.fx.Proxy): + splits = [0] * len(inputs) + else: + splits = [x.shape[0] for x in inputs] + + return torch.split(result, splits) + + +def may_depend_on(a: Node, b: Node, search_depth: int = 6): + """ + Determine if one node depends on another in a torch.fx.Graph. + + Arguments: + a: The node that may have a dependency on b. + b: The node that a may have a dependency on. + search_depth: In the case of an indirect dependency, this function + searches upto this many nodes away in search of a + data dependency. If none is found, the function + makes the conservative assumption that there is a + dependency. + + Returns: + True if a may depend on b, False if it definitely does not. + """ + # Equivalence is defined as dependence. + if a == b: + return True + + # If a has no inputs, it cannot depend on b. + if len(a.all_input_nodes) == 0: + return False + + # If the search depth has been exhausted and no conclusion has been + # reached, assume that there is a data dependency. + if search_depth == 0: + return True + + # Recursively check all inputs of a. + for inp in a.all_input_nodes: + if may_depend_on(inp, b, search_depth - 1): + return True + + return False + + +def are_nodes_independent(nodes: list[Node]): + """ + Check if all of the given nodes are pairwise-data independent. + + Arguments: + nodes: The nodes to check for data dependencies. + + Returns: + True if any pair in nodes has a data dependency. + """ + # For each pair in nodes: + for i, j in itertools.combinations(nodes, 2): + if may_depend_on(i, j) or may_depend_on(j, i): + return False + + return True + + +def merge_matmul(in_mod: torch.nn.Module): + """ + A graph transformation that merges matrix multiplication operations that share the same right-hand + side operand into one large matrix multiplication. + ____ _________ _________ + ---- | | | | M| A * C | + M| A | T| B | * K| C | = |---------| + ---- , | | | | T| B * C | + K ---- --------- --------- + K R R + """ + gm = symbolic_trace(in_mod) + + rhs_users: dict[Node, list[Node]] = {} + lhs_users: dict[Node, list[Node]] = {} + + # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to + # the matmul of which they are the LHS/RHS. + for node in gm.graph.nodes: + if node.op != "call_function" or node.target is not torch.matmul: + continue + + lhs, rhs = node.args + + # TODO: Properly handle aliasing caused by get_attr. For now, + # use the attribute name as the operand if the node is a + # get_attr. + lhs = lhs.target if lhs.op == "get_attr" else lhs + rhs = rhs.target if rhs.op == "get_attr" else rhs + + lhs_users.setdefault(lhs, []).append(node) + rhs_users.setdefault(rhs, []).append(node) + + for rhs, mms in rhs_users.items(): + # There must be at least matmuls for a merge to make sense. + if len(mms) < 2: + continue + + # All matmuls must not depend on each other directly or indirectly + # in order for the merge to be possible. + if not are_nodes_independent(mms): + continue + + lhs_vals = [mm.args[0] for mm in mms] + + # Merge the matmul. + # Collect a list of LHS operands and the single RHS operand. + lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals] + rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs + + # Concatenate all the LHS operands. + merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {}) + + # Multiply the concatenated LHS operands with the one RHS. This will produce + # the same results as all the individual matmuls involving rhs in the original graph, + # but they will all be concatenated together. + merge_mm = gm.graph.call_function( + torch.matmul, + ( + merge_mm_cat, + rhs, + ), + {}, + ) + + # Split the result of the merged matmul using the shapes of the LHS operands + # to ascertain how large each chunk should be. + merge_mm_split = gm.graph.call_function( + split_result_tensors, (merge_mm, lhs), {} + ) + merge_mm_res = [ + gm.graph.call_function(operator.getitem, (merge_mm_split, out), {}) + for out in range(len(lhs)) + ] + + # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul. + for old, new in zip(mms, merge_mm_res): + old.replace_all_uses_with(new) + gm.graph.erase_node(old) + + # All of the new nodes created above were inserted at the end, so we need to sort + # the nodes topologically to make sure all definitions precede uses. + legalize_graph(gm) + + gm.recompile() + gm.graph.lint() + return gm diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/meta_tracer.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/meta_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..bc00be5ee7ae823e5fab5074867f133c77c3fe91 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/meta_tracer.py @@ -0,0 +1,311 @@ +# mypy: allow-untyped-defs +import builtins +import functools +import warnings +from typing import Any, Callable, Optional, Union + +import torch +import torch.fx + + +def embedding_override(self, input): + return torch.empty(*input.shape, self.weight.shape[-1], device="meta") + + +def nn_layernorm_override(self, input): + return input + + +def torch_relu_override(x): + return x + + +def torch_nn_relu_override(self, x): + return x + + +def functional_relu_override(x, inplace=False): + assert not inplace, "dont support inplace functional.relu for metatensor analysis" + return x + + +def torch_where_override(condition, x, y): + # torch.where returns the broadcasted tensor of condition, x, and y, + # so hack it by using addition + return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta") + + +def torch_abs_override(input, *, out=None): + assert out is None, "Dont support in-place abs for MetaTensor analysis" + return input + + +manual_meta_overrides: dict[Callable, Callable] = { + torch.nn.Embedding: embedding_override, + torch.nn.LayerNorm: nn_layernorm_override, + torch.relu: torch_relu_override, + torch.nn.functional.relu: functional_relu_override, + torch.nn.ReLU: torch_nn_relu_override, + torch.where: torch_where_override, + torch.abs: torch_abs_override, +} + + +def gen_constructor_wrapper(target): + @functools.wraps(target) + def wrapper(*args, **kwargs): + proxy = None + + def check_has_proxy(v): + if isinstance(v, torch.fx.Proxy): + nonlocal proxy + proxy = v + + torch.fx.node.map_aggregate(args, check_has_proxy) + torch.fx.node.map_aggregate(kwargs, check_has_proxy) + + if proxy is not None: + return proxy.tracer.create_proxy("call_function", target, args, kwargs) + else: + return target(*args, **kwargs) + + return wrapper, target + + +class MetaProxy(torch.fx.Proxy): + def install_tensor_meta(self, tensor_meta): + self._tensor_meta = tensor_meta + + def size(self, dim=None): + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: + return self._tensor_meta.size(*[dim] if dim else []) + return self.tracer.create_proxy( + "call_method", "size", (self, dim) if dim else (self,), {} + ) + + def dim(self): + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: + return self._tensor_meta.dim() + return self.tracer.create_proxy("call_method", "dim", (self,), {}) + + @property + def shape(self): + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: + return self._tensor_meta.shape + return self.tracer.create_proxy( + "call_function", builtins.getattr, (self, "shape"), {} + ) + + @property + def dtype(self): + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: + return self._tensor_meta.dtype + return self.tracer.create_proxy( + "call_function", builtins.getattr, (self, "dtype"), {} + ) + + @property + def device(self): + # Hack so we can track when devices are used. During meta-tensor propagation, + # replace these values with a constant 'meta' + return MetaDeviceAttribute(self, "device") + + def __getattr__(self, k): + if k == "_tensor_meta": + return self.__getattribute__(k) + # note: not added to the graph yet, if this is a method call + # we peephole optimize to the method invocation + return MetaAttribute(self, k) + + +class MetaAttribute(MetaProxy): + def __init__(self, root, attr: str): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._node = None + + @property + def node(self): # type: ignore[override] + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy( + "call_function", getattr, (self.root, self.attr), {} + ).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy( + "call_method", self.attr, (self.root,) + args, kwargs + ) + + +class MetaDeviceAttribute(MetaAttribute): + pass + + +def proxys_to_metas(v): + if isinstance(v, MetaDeviceAttribute): + return "meta" + if isinstance(v, torch.fx.Proxy): + assert isinstance(v, MetaProxy), f"Expected MetaProxy but got {type(v)}" + assert hasattr(v, "_tensor_meta"), "MetaProxy does not have an associated meta" + return v._tensor_meta + return v + + +class MetaTracer(torch.fx.Tracer): + allow_insert_stateless_mods: bool = True + + _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"] + + def create_proxy( + self, + kind, + target, + args, + kwargs, + name=None, + type_expr=None, + proxy_factory_fn=None, + ): + rv = super().create_proxy( + kind, target, args, kwargs, name, type_expr, proxy_factory_fn + ) + + if kind == "placeholder" and target in self.meta_args: + rv.install_tensor_meta(self.meta_args[target]) + return rv + + if target in self.orig_fns: + # NOTE: tensor constructors in PyTorch define the `device` argument as + # *kwargs-only*. That is why this works. If you add methods to + # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, + # this will break and you will likely see issues where we cannot infer + # the size of the output. + if "device" in kwargs: + kwargs["device"] = "meta" + + try: + args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas) + kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas) + + if kind == "call_function": + meta_target = manual_meta_overrides.get(target, target) + meta_out = meta_target(*args_metas, **kwargs_metas) + elif kind == "call_method": + meta_target = getattr(args_metas[0], target) # type: ignore[index] + meta_out = meta_target(*args_metas[1:], **kwargs_metas) # type: ignore[index] + elif kind == "call_module": + assert hasattr(self, "orig_forward") + self._disable_module_getattr = True + try: + mod = self.root.get_submodule(target) + mod_type = type(mod) + if mod_type in manual_meta_overrides: + meta_out = manual_meta_overrides[mod_type]( + mod, *args_metas, **kwargs_metas + ) # type: ignore[misc, arg-type] + else: + meta_out = self.orig_forward(*args_metas, **kwargs_metas) + finally: + self._disable_module_getattr = False + elif kind == "get_attr": + self._disable_module_getattr = True + try: + attr_itr = self.root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + assert isinstance(attr_itr, torch.Tensor) + meta_out = attr_itr.to(device="meta") + finally: + self._disable_module_getattr = False + else: + return rv + + # TODO + assert isinstance(rv, torch.fx.Proxy), "Dont support composite output yet" + rv.install_tensor_meta(meta_out) + except Exception as e: + warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}") + + return rv + + def getattr(self, attr, attr_val, parameter_proxy_cache): + if getattr(self, "_disable_module_getattr", False): + return attr_val + else: + return super().getattr(attr, attr_val, parameter_proxy_cache) + + def call_module(self, m, forward, args, kwargs): + self.orig_forward = forward + return super().call_module(m, forward, args, kwargs) + + def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str: + """ + Helper method which tries to insert a module that was not declared as submodule. + """ + idx = 0 + mod_name = mod.__class__.__name__.lower() + path = f"{mod_name}_{idx}" + while hasattr(self.root, path): + path = f"{mod_name}_{idx}" + idx += 1 + + self.root.add_module(path, mod) + return path + + def path_of_module(self, mod: torch.nn.Module) -> str: + try: + return super().path_of_module(mod) + except NameError: + if ( + self.allow_insert_stateless_mods + and len(list(mod.parameters())) == 0 + and len(list(mod.buffers())) == 0 + ): + path = self._insert_module_as_submodule(mod) + self.prev_module = path + return path + raise + + def proxy(self, node): + return MetaProxy(node, self) + + def trace(self, root, meta_args: dict[str, torch.Tensor], concrete_args=None): # type: ignore[override] + assert isinstance(meta_args, dict) + self.meta_args = meta_args + + self.patched_torch_methods = { + target: gen_constructor_wrapper(getattr(torch, target)) + for target in self._TORCH_METHODS_TO_PATCH + } + self.orig_fns = set() + + for name, (wrapper, orig) in self.patched_torch_methods.items(): + setattr(torch, name, wrapper) + self.orig_fns.add(orig) + + try: + graph = super().trace(root, concrete_args) + graph._tracer_extras = {"meta_args": meta_args} + return graph + finally: + for name, (_, orig) in self.patched_torch_methods.items(): + setattr(torch, name, orig) + + +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + meta_args: Optional[dict[str, torch.Tensor]] = None, + concrete_args: Optional[dict[str, Any]] = None, +) -> torch.fx.GraphModule: + tracer = MetaTracer() + graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type] + name = ( + root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + ) + gm = torch.fx.GraphModule(tracer.root, graph, name) + return gm diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py new file mode 100644 index 0000000000000000000000000000000000000000..8aca3e482c95f75506621475935a9d3b8b879d4e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -0,0 +1,643 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_div, + op_eq, + op_gt, + op_lt, + op_mod, + op_mul, + op_neq, + op_sub, +) +from torch.fx.tensor_type import Dyn, TensorType + + +class Constraint: + pass + + +class Conj(Constraint): + def __init__(self, conjuncts): + """ + :param conjuncts: Conjunction of constraints + """ + self.conjucts = conjuncts + + def __eq__(self, other): + if isinstance(other, Conj): + return self.conjucts == other.conjucts and self.conjucts == other.conjucts + else: + return False + + def __repr__(self): + return f"And({self.conjucts})" + + +class Disj(Constraint): + def __init__(self, disjuncts): + """ + :param disjuncts: Disjunction of constraints + """ + self.disjuncts = disjuncts + + def __eq__(self, other): + if isinstance(other, Disj): + return ( + self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts + ) + else: + return False + + def __repr__(self): + return f"Or({self.disjuncts})" + + +class Prod(Constraint): + def __init__(self, products): + """ + :param products: lists of dimensions to multiply + """ + self.products = products + + def __eq__(self, other): + if isinstance(other, Prod): + return self.products == other.products and self.products == other.products + else: + return False + + def __repr__(self): + return f"Product({self.products})" + + +class T(Constraint): + """ + True + """ + + def __init__(self) -> None: + pass + + def __eq__(self, other): + return isinstance(other, T) + + def __repr__(self): + return "True" + + +class F(Constraint): + """ + False + """ + + def __init__(self) -> None: + pass + + def __eq__(self, other): + return isinstance(other, F) + + def __repr__(self): + return "False" + + +class BinaryConstraint(Constraint): + """ + Represents all binary operations + """ + + def __init__(self, lhs, rhs, op): + """ + :param lhs: lhs of the constraint + :param rhs: rhs of the constraint + :param op: string representing the operation + """ + self.lhs = lhs + self.rhs = rhs + self.op = op + + def __eq__(self, other): + if isinstance(other, BinaryConstraint): + return ( + self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op + ) + else: + return False + + def __repr__(self): + return f"({self.lhs} {self.op} {self.rhs})" + + +class BinConstraintT(BinaryConstraint): + """ + Binary constraints about tensors + """ + + def __init__(self, lhs, rhs, op): + assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and ( + isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn + ) + super().__init__(lhs, rhs, op) + + def __eq__(self, other): + return super().__eq__(other) + + +class BinConstraintD(BinaryConstraint): + """ + Binary constraints about dimensions + """ + + def __init__(self, lhs, rhs, op): + assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs) + assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs) + + super().__init__(lhs, rhs, op) + + def __eq__(self, other): + return super().__eq__(other) + + +class TGreatestUpperBound(Constraint): + """ + Greatest Upper bound for tensors with dynamic type + """ + + def __init__(self, res, rhs1, rhs2): + """ + :param res: tensor variable that stores the result of the outout + :param rhs1: tensor or tensor variable + :param rhs2: tensor or tensor variabke + """ + self.res = res + self.rhs1 = rhs1 + self.rhs2 = rhs2 + + def __repr__(self): + return f"{self.res} = {self.rhs1}\u2294*{self.rhs2}" + + def __eq__(self, other): + if isinstance(other, TGreatestUpperBound): + return ( + self.res == other.res + and self.rhs1 == other.rhs1 + and self.rhs2 == other.rhs2 + ) + else: + return False + + +class DGreatestUpperBound(Constraint): + """ + Greatest Upper bound for dimensions + """ + + def __init__(self, res, rhs1, rhs2): + """ + :param res: Dimension variable to store the result + :param rhs1: dimension variable 1 + :param rhs2: dimension variable 2 + """ + assert is_dim(res) + assert is_dim(rhs1) + assert is_dim(rhs2) + + self.res = res + self.rhs1 = rhs1 + self.rhs2 = rhs2 + + def __repr__(self): + return f"{self.res} = {self.rhs1}\u2294{self.rhs2}" + + def __eq__(self, other): + if isinstance(other, DGreatestUpperBound): + return ( + self.res == other.res + and self.rhs1 == other.rhs1 + and self.rhs2 == other.rhs2 + ) + else: + return False + + +class CanReshape(Constraint): + """ + can_reshape constraint + """ + + def __init__(self, src, target): + """ + :param src: tensor variable + :param target: tensor + """ + self.src = src + self.target = target + + def __repr__(self): + return f"can-reshape({self.src}, {self.target})" + + def __eq__(self, other): + if isinstance(other, CanReshape): + return self.src == other.src and self.target == other.target + else: + return False + + +class IndexSelect(Constraint): + def __init__(self, tensor_size, input_var, dim_replace, index, output): + """ + Args: + input_var: input to index_select + tensor_size: tensor size we are considering + dim_replace: the dimension of the output at "index" + index: location of the dimensions to replace in the input + output: variable to store the result + """ + assert isinstance(input_var, TVar) + assert isinstance(output, TVar) + assert isinstance(dim_replace, DVar) or dim_replace == Dyn + assert isinstance(index, int) + + self.input_var = input_var + self.tensor_size = tensor_size + self.dim_replace = dim_replace + self.index = index + self.output = output + + def __repr__(self): + return ( + f" {self.output} = " + f"IndexSelect({self.input_var}, " + f"tensor_size: {self.tensor_size}, " + f"{self.dim_replace}, " + f"{self.index})" + ) + + def __eq__(self, other): + if isinstance(other, IndexSelect): + return ( + self.tensor_size == other.tensor_size + and self.dim_replace == other.dim_replace + and self.index == other.index + and self.output == other.output + and self.input_var == other.input_var + ) + else: + return False + + +class Transpose(Constraint): + def __init__(self, tensor_size, input_var, index1, index2, output): + """ + Args: + tensor_size: current tensor size + input_var: variable to hold input + index1: dimension 1 + index2: dimension 2 + output: output that stores result + """ + assert isinstance(input_var, TVar) + assert isinstance(output, TVar) + assert isinstance(index1, int) + assert isinstance(index2, int) + + self.input_var = input_var + self.tensor_size = tensor_size + self.index1 = index1 + self.index2 = index2 + self.output = output + + def __repr__(self): + return ( + f" {self.output} = " + f"Transpose({self.input_var}, " + f"tensor_size: {self.tensor_size}, " + f"{self.index1}, " + f"{self.index2})" + ) + + def __eq__(self, other): + if isinstance(other, Transpose): + return ( + self.tensor_size == other.tensor_size + and self.index1 == other.index1 + and self.index2 == other.index2 + and self.output == other.output + and self.input_var == other.input_var + ) + else: + return False + + +class GetItem(Constraint): + def __init__(self, tensor_size, index, res, input_var): + """ + Constraint for getting item given a tensor size + :param tensor_size: actual number + :param index: actual number representing the index + :param res: dimension variable to carry the item we get + :param input_var: a tensor variable from which we will get item + """ + assert isinstance(res, DVar) + + self.res = res + self.tensor_size = tensor_size + self.index = index + self.input_var = input_var + + def __repr__(self): + return f" {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})" + + def __eq__(self, other): + if isinstance(other, GetItem): + return ( + self.res == other.res + and self.tensor_size == other.tensor_size + and self.index == other.index + and self.input_var == other.input_var + ) + else: + return False + + +class GetItemTensor(Constraint): + def __init__(self, tensor_size, index_tuple, res, input_var): + """ + Constraint for getting item given a tensor size + However, when the argument is a tuple, we will + expect a tensor + :param tensor_size: actual number representing the rank + :param index_tuple: tuple for indexing + :param res: tensor variable to carry the item we get + :param input_var: a tensor variable from which we will get item + """ + assert isinstance(res, TVar) + + self.res = res + self.tensor_size = tensor_size + self.index_tuple = index_tuple + self.input_var = input_var + + def __repr__(self): + return f" {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})" + + def __eq__(self, other): + if isinstance(other, GetItemTensor): + return ( + self.res == other.res + and self.tensor_size == other.tensor_size + and self.index_tuple == other.index_tuple + and self.input_var == other.input_var + ) + else: + return False + + +class CalcConv(Constraint): + def __init__( + self, + conv_result, + input_var, + c_out, + kernel, + padding, + stride, + dilation, + matching_constraint_vars, + ): + """ + :param conv_result: the convolution result + :param input_var: input to convolution + :param c_out: output chanel type + :param kernel: kernel tuple + """ + self.conv_result = conv_result + self.input_var = input_var + self.c_out = c_out + self.kernel = kernel + self.padding = padding + self.stride = stride + self.dilation = dilation + self.matching_constraint = matching_constraint_vars + + def __repr__(self): + return ( + f"{self.conv_result} =" + f" calc-conv({self.input_var}," + f" {self.c_out}, {self.kernel}, " + f"{self.padding}, {self.stride}," + f" {self.dilation})" + ) + + def __eq__(self, other): + if isinstance(other, CalcConv): + return ( + self.conv_result == other.conv_result + and self.input_var == other.input_var + and self.c_out == other.c_out + and self.kernel == other.kernel + and self.padding == other.padding + and self.stride == other.stride + and self.dilation == other.dilation + and self.matching_constraint == other.matching_constraint + ) + else: + return False + + +class CalcMaxPool(Constraint): + def __init__( + self, + maxpool_result, + input_var, + kernel, + padding, + stride, + dilation, + matching_constraint_vars, + ): + """ + :param maxpool_result: the result of maxpool + :param input_var: input to convolution + :param kernel: kernel tuple + """ + self.maxpool_result = maxpool_result + self.input_var = input_var + self.kernel = kernel + self.padding = padding + self.stride = stride + self.dilation = dilation + self.matching_constraint = matching_constraint_vars + + def __repr__(self): + return ( + f"{self.maxpool_result} =" + f" calc-maxpool({self.input_var}," + f" {self.kernel}, " + f"{self.padding}, {self.stride}," + f" {self.dilation})" + ) + + def __eq__(self, other): + if isinstance(other, CalcMaxPool): + return ( + self.maxpool_result == other.maxpool_result + and self.input_var == other.input_var + and self.kernel == other.kernel + and self.padding == other.padding + and self.stride == other.stride + and self.dilation == other.dilation + and self.matching_constraint == other.matching_constraint + ) + else: + return False + + +class ApplyBroadcasting(Constraint): + def __init__(self, res1, res2, input1, input2): + """ + :param res1: resulting tensor 1 + :param res2: resulting tensor 2 + :param input1: tensor variable 1 + :param input2: tensor variable 2 + """ + self.res1 = res1 + self.res2 = res2 + self.input1 = input1 + self.input2 = input2 + + def __eq__(self, other): + if isinstance(other, ApplyBroadcasting): + return ( + self.res1 == other.res1 + and self.res2 == other.res2 + and self.input1 == other.input1 + and self.input2 == other.input2 + ) + else: + return False + + def __repr__(self): + return ( + f"{self.res1}, {self.res2} =" + f" apply-broadcasting({self.input1}," + f" {self.input2})" + ) + + +class CalcProduct(Constraint): + """ + Given correct dimensions, calculate the product for flatten accounting for Dyn + """ + + def __init__(self, start, end, flattened, dims_to_flatten): + """ + :param start: start index + :param end: end index + :param flattened: variable to store the product + :param dims_to_flatten: the type which we will flatten + """ + assert isinstance(dims_to_flatten, list) + assert isinstance(flattened, TVar) + assert isinstance(start, int) + assert isinstance(end, int) + + self.start = start + self.end = end + self.dims_to_flatten = dims_to_flatten + self.flattened = flattened + + def __eq__(self, other): + if isinstance(other, CalcProduct): + return ( + self.start == other.start + and self.end == other.end + and self.dims_to_flatten == other.dims_to_flatten + and self.flattened == other.flattened + ) + + else: + return False + + def __repr__(self): + return f"{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})" + + +class TVar: + """ + Tensor variable with no tensor constructor + """ + + def __init__(self, tvar): + """ + :param tvar: tensor variable + """ + self.tvar = tvar + + def __repr__(self): + return f"TV({self.tvar})" + + def __eq__(self, other): + if isinstance(other, TVar): + return self.tvar == other.tvar + else: + return False + + +class DVar: + """ + Dimension variable + """ + + def __init__(self, c): + """ + :param c: character or number + """ + self.c = c + + def __repr__(self): + return f"DV({self.c})" + + def __eq__(self, other): + if isinstance(other, DVar): + return self.c == other.c + else: + return False + + +class BVar: + """ + Boolean variable + """ + + def __init__(self, c): + """ + :param c: character or number + """ + self.c = c + + def __repr__(self): + return f"BV({self.c})" + + def __eq__(self, other): + if isinstance(other, BVar): + return self.c == other.c + else: + return False + + +def is_algebraic_expression(constraint): + if isinstance(constraint, BinConstraintD): + return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod] + else: + return isinstance(constraint, Prod) + + +def is_bool_expr(constraint): + if isinstance(constraint, BinConstraintD): + return constraint.op in [op_gt, op_lt, op_neq, op_eq] + else: + return isinstance(constraint, (BVar, Conj, Disj)) + + +def is_dim(d): + return isinstance(d, (DVar, int)) or d == Dyn diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..03346b800924e5db336579e724bd823757acb9b5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -0,0 +1,1562 @@ +# mypy: allow-untyped-defs +import operator +import warnings +from collections.abc import Iterable +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +import torch +from torch.fx._symbolic_trace import _assert_is_none +from torch.fx.experimental.migrate_gradual_types.constraint import ( + ApplyBroadcasting, + BinConstraintD, + BinConstraintT, + CalcConv, + CalcMaxPool, + CalcProduct, + CanReshape, + Conj, + DGreatestUpperBound, + Disj, + DVar, + F, + GetItem, + GetItemTensor, + IndexSelect, + T, + TGreatestUpperBound, + Transpose, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_consistency, + op_div, + op_eq, + op_gt, + op_leq, + op_lt, + op_matching, + op_mul, + op_neq, + op_precision, + op_sub, +) +from torch.fx.experimental.migrate_gradual_types.util import ( + gen_bvar, + gen_dvar, + gen_nat_constraints, + gen_tensor_dims, + gen_tvar, +) +from torch.fx.node import Node, Target +from torch.fx.tensor_type import Dyn, TensorType +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.conv import Conv2d + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +_INFERENCE_RULES: dict[Target, Callable] = {} + +MAX_TENSOR_RANK = 4 + +__all__ = [ + "ConstraintGenerator", + "adaptive_inference_rule", + "add_layer_norm_constraints", + "add_linear_constraints", + "arange_inference_rule", + "assert_inference_rule", + "batchnorm_inference_rule", + "bmm_inference_rule", + "broadcasting_inference_rule", + "conv2d_inference_rule", + "cumsum_inference_rule", + "embedding_inference_rule", + "embedding_inference_rule_functional", + "eq_inference_rule", + "equality_inference_rule", + "expand_inference_rule", + "flatten_inference_rule", + "full_inference_rule", + "gen_broadcasting_constraints", + "gen_embedding_rules", + "gen_layer_norm_constraints", + "generate_flatten_constraints", + "get_attr_inference_rule", + "getitem_inference_rule", + "gt_inference_rule", + "index_select_inference_rule", + "layer_norm_functional", + "layer_norm_inference_rule", + "linear_constraints", + "linear_inference_rule", + "lt_inference_rule", + "masked_fill_inference_rule", + "maxpool_inference_rule", + "neq_inference_rule", + "range_check", + "register_inference_rule", + "relu_inference_rule", + "reshape_inference_rule", + "size_inference_rule", + "tensor_inference_rule", + "torch_dim_inference_rule", + "torch_linear_inference_rule", + "transpose_inference_rule", + "type_inference_rule", + "view_inference_rule", +] + + +def register_inference_rule( + call_target: Target, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if call_target in _INFERENCE_RULES: + raise RuntimeError(f"Inference rule already registered for {call_target}!") + _INFERENCE_RULES[call_target] = fn + return fn + + return register + + +def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter): + d, counter = gen_tensor_dims(n, counter) + c1 = BinConstraintT(input, TensorType(d), op_eq) + start_dim = n if start_dim == -1 else abs(start_dim) + end_dim = n + end_dim + 1 if end_dim < 0 else end_dim + 1 + c2 = CalcProduct(start_dim, end_dim, flattened, d) + nat_constraints = gen_nat_constraints(d) + return Conj([c1, c2, *nat_constraints]), counter + + +@register_inference_rule(getattr) +def get_attr_inference_rule(n: Node, symbols, constraints, counter): + """ + If the attribute is "device" then the tensor shape is preserved + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], str) + output, counter = gen_tvar(counter) + symbols[n] = output + + input = symbols[n.args[0]] + attr = n.args[1] + + if attr == "device": + return [BinConstraintT(input, output, op_eq)], counter + else: + raise NotImplementedError("Not yet implemented") + + +@register_inference_rule(torch.bmm) +def bmm_inference_rule(n: Node, symbols, constraints, counter): + """ + Constraints that match the input to a size 3 tensor + and switch the dimensions according to the rules + of batch multiplication + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + bmm_output, counter = gen_tvar(counter) + symbols[n] = bmm_output + + bmm_input1 = symbols[n.args[0]] + bmm_input2 = symbols[n.args[1]] + + dims_input1, counter = gen_tensor_dims(3, counter) + dims_input2, counter = gen_tensor_dims(3, counter) + + inputs_dyn = Conj( + [ + BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_output, Dyn, op_eq), + ] + ) + + input1_dyn = Conj( + [ + BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT( + bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq + ), + ] + ) + + input2_dyn = Conj( + [ + BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT( + bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq + ), + ] + ) + + consistency_constraints = [ + BinConstraintD(dims_input1[0], dims_input2[0], op_consistency) + ] + + batch_size, counter = gen_dvar(counter) + + inputs_are_tensors = Conj( + [ + BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT( + bmm_output, + TensorType([batch_size, dims_input1[1], dims_input2[2]]), + op_eq, + ), + *consistency_constraints, + DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0]), + ] + ) + + return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter + + +@register_inference_rule("index_select") +def index_select_inference_rule(n: Node, symbols, constraints, counter): + """ + We constrain the second argument to a vector or Dyn. + The output replaces the input with the shape of the vector + at the position given by the index (first argument) + """ + # print(n.args) + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], Node) + + index_select, counter = gen_tvar(counter) + symbols[n] = index_select + + dims, counter = gen_tensor_dims(1, counter) + + # equality constraint + is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq) + is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq) + + c2 = Conj( + [ + is_size_1, + Disj( + [ + IndexSelect( + i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select + ) + for i in range(MAX_TENSOR_RANK) + ] + ), + ] + ) + c3 = Conj( + [ + is_dyn, + Disj( + [ + IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select) + for i in range(MAX_TENSOR_RANK) + ] + ), + ] + ) + + return [Disj([c2, c3])], counter + + +@register_inference_rule("expand") +def expand_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the exact constraints as we do for tensor additions but we constraint + the rank of this expression to be equal to len(n.args[1:]) so that only + those cases get considered for the output + """ + assert isinstance(n.args[0], Node) + + # define the output for expand + expand, counter = gen_tvar(counter) + symbols[n] = expand + + # since we do not have two nodes here, we will construct an argument variable + e1 = symbols[n.args[0]] + e2, counter = gen_tvar(counter) + + e2_nat_constraints = [] + for arg in n.args[1:]: + assert isinstance(arg, (Node, int)) + if isinstance(arg, Node): + assert isinstance(symbols[arg], DVar) + e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq)) + + e2_constraint = BinConstraintT( + e2, + TensorType( + [arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]] + ), + op_eq, + ) + + constraints, counter = gen_broadcasting_constraints( + e1, e2, symbols, counter, expand + ) + + # constraint the output size + dims, counter = gen_tensor_dims(len(n.args[1:]), counter) + nat_constraints = gen_nat_constraints(dims) + c = [ + BinConstraintT(expand, TensorType(dims), op_eq), + *nat_constraints, + e2_constraint, + *e2_nat_constraints, + ] + constraints += c + + return constraints, counter + + +@register_inference_rule(torch.nn.functional.gelu) +@register_inference_rule(torch.nn.functional.dropout) +@register_inference_rule(torch.nn.functional.softmax) +@register_inference_rule("detach") +@register_inference_rule("to") +@register_inference_rule("int") +@register_inference_rule("long") +@register_inference_rule("contiguous") +@register_inference_rule(torch.ones) +@register_inference_rule(torch.zeros) +def equality_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + output, counter = gen_tvar(counter) + symbols[n] = output + + if isinstance(n.args[0], Node): + input = symbols[n.args[0]] + if isinstance(input, TVar): + return [BinConstraintT(input, output, op_eq)], counter + + # then we have dimension variables + else: + for arg in n.args: + assert isinstance(symbols[arg], DVar) + my_size = [symbols[arg] for arg in n.args] + return [BinConstraintT(output, TensorType(my_size), op_eq)], counter + + elif isinstance(n.args[0], tuple): + # then the tuple is the size + assert len(n.args[0]) <= 4 + my_size = [symbols[arg] for arg in n.args[0]] + return [BinConstraintT(output, TensorType(my_size), op_eq)], counter + else: + raise NotImplementedError("Method not yet implemented") + + +@register_inference_rule("transpose") +def transpose_inference_rule(n: Node, symbols, constraints, counter): + """ + Can be considered as a sequence of two index selects, so we generate constraints accordingly + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], int) + + output, counter = gen_tvar(counter) + symbols[n] = output + + from_arg = symbols[n.args[0]] + assert isinstance(from_arg, TVar) + + # input and output are dyn + is_dyn = Conj( + [BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)] + ) + + # or input is a tensor and we actually do the replacement + c3 = Disj( + [ + Transpose(i + 1, from_arg, n.args[1], n.args[2], output) + for i in range(MAX_TENSOR_RANK) + ] + ) + + return [Disj([is_dyn, c3])], counter + + +@register_inference_rule("type_as") +def type_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + output, counter = gen_tvar(counter) + symbols[n] = output + + from_arg = symbols[n.args[0]] + to_arg = symbols[n.args[1]] + + assert isinstance(from_arg, TVar) + assert isinstance(to_arg, TVar) + + return [ + BinConstraintT(from_arg, to_arg, op_consistency), + BinConstraintT(output, to_arg, op_eq), + ], counter + + +@register_inference_rule("masked_fill_") +def masked_fill_inference_rule(n: Node, symbols, constraints, counter): + """ + Similar to addition. For now we implement the constraints when + the argument is a boolean tensor. There is also a case for when + it is a condition. We will leave this out for now. + """ + + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + # We will retrieve the type variables from the symbol table + # and confirm they are tensor variables + + e1 = symbols[n.args[0]] + e2 = symbols[n.args[1]] + + if isinstance(e1, TVar) and isinstance(e2, TVar): + masked_fill_tensor, counter = gen_tvar(counter) + symbols[n] = masked_fill_tensor + return gen_broadcasting_constraints( + e1, e2, symbols, counter, masked_fill_tensor + ) + else: + raise NotImplementedError("Not yet implemented") + + +@register_inference_rule(torch.nn.functional.embedding) +def embedding_inference_rule_functional(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + embedding_dim_weights = symbols[n.args[1]] + + # will treat this as a static shape. So we will not use matching. + weight_dims, counter = gen_tensor_dims(2, counter) + equality_constraint = BinConstraintT( + embedding_dim_weights, TensorType(weight_dims), op_eq + ) + embedding_dim = weight_dims[1] + constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter) + return [equality_constraint] + constraints, counter + + +@register_inference_rule(torch.nn.modules.sparse.Embedding) +def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + The output shape differs from the input shape in the last dimension + """ + assert isinstance(n.args[0], Node) + return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter) + + +def gen_embedding_rules(n: Node, symbols, embedding_dim, counter): + embedding_output, counter = gen_tvar(counter) + symbols[n] = embedding_output + embedding_input = symbols[n.args[0]] + + input_dyn = BinConstraintT(embedding_input, Dyn, op_eq) + output_dyn = BinConstraintT(embedding_output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + c2 = [] + + for i in range(1, MAX_TENSOR_RANK): + new_dims, counter = gen_tensor_dims(i, counter) + nat_constraints = gen_nat_constraints(new_dims) + + # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases + c_tensor_i = Conj( + [ + BinConstraintT(embedding_input, TensorType(new_dims), op_eq), + BinConstraintT( + embedding_output, TensorType(new_dims + [embedding_dim]), op_eq + ), + ] + + nat_constraints + ) + c2.append(c_tensor_i) + + return [Disj([c1, Disj(c2)])], counter + + +@register_inference_rule(torch.tensor) +def tensor_inference_rule(n: Node, symbols, constraints, counter): + """ + If the tensor is a scalar, we will skip it since we + do not support scalars yet. We will add support in the future + if it's needed. For our examples so far, scalars are not needed. + """ + return [], counter + + +@register_inference_rule("reshape") +@register_inference_rule("view") +def view_inference_rule(n: Node, symbols, constraints, counter): + """ + Similar to reshape but with an extra condition on the strides + """ + assert isinstance(n.args[0], Node) + + # generate the new variable + my_view, counter = gen_tvar(counter) + symbols[n] = my_view + + src_var = symbols[n.args[0]] + t2 = [ + symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:] + ] # target shape + t2_type = [] + num_constraints = [] + + for t in t2: + if t == -1: + var, counter = gen_dvar(counter) + t2_type.append(var) + num_constraints.append(BinConstraintD(var, Dyn, op_neq)) + + else: + num_constraints.append(BinConstraintD(t, Dyn, op_neq)) + t2_type.append(t) + + t2_type = TensorType(t2_type) # type: ignore[assignment] + + c1 = BinConstraintT(my_view, t2_type, op_eq) + c2 = CanReshape(src_var, t2_type) + + # TODO: add the extra check mentioned here: + # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view + + return [c1, c2] + num_constraints, counter # type: ignore[operator] + + +@register_inference_rule("size") +def size_inference_rule(n: Node, symbols, constraints, counter): + """ + The constraint is just lhs = rhs. + Ex: size = input_ids.size() + """ + + if len(n.args) == 1: + # generate the new variable + size, counter = gen_tvar(counter) + symbols[n] = size + input = symbols[n.args[0]] + c = BinConstraintT(input, size, op_eq) + return [c], counter + + elif len(n.args) == 2: + # TODO: review this rule; should input = dyn; output = dyn be included here? + if isinstance(n.args[1], int): + # generate the new variable + size_index, counter = gen_dvar(counter) + symbols[n] = size_index + input = symbols[n.args[0]] + c2 = [ + GetItem(i + 1, n.args[1], size_index, input) + for i in range(MAX_TENSOR_RANK) + ] + c3 = BinConstraintD(0, size_index, op_leq) + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintD(size_index, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + + return [Disj([c1, Conj([Disj(c2), c3])])], counter + + else: + raise NotImplementedError + + else: + raise NotImplementedError + + +def range_check(i, n): + """ + Checks if an index i is within range of a size n list + Args: + i: index + n: list size + + Returns: Boolean + """ + if i >= 0: + return T() if i < n else F() + else: + return T() if i >= n else F() + + +@register_inference_rule(torch.cumsum) +def cumsum_inference_rule(n: Node, symbols, constraints, counter): + """ + Input and output shapes should be equal + We should verify that the index is valid + """ + assert isinstance(n.args[0], Node) + arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"] + assert isinstance(arg_1, int) + + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintT(output, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims) + + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims), op_eq), + BinConstraintT(output, TensorType(new_dims), op_eq), + ] + + [range_check(arg_1, i)] + + nat_constraints + ) + + c2.append(c_tensor_i) + dyn_or_tensor = Disj([c1, Disj(c2)]) + return [dyn_or_tensor], counter + + +@register_inference_rule(_assert_is_none) +def assert_inference_rule(n: Node, symbols, constraints, counter): + assert len(n.users) == 0 + return [], counter + + +@register_inference_rule(operator.getitem) +def getitem_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # dimension output case + if isinstance(n.args[1], int): + # create and store the new dimension variable + get_item_output, counter = gen_dvar(counter) + symbols[n] = get_item_output + + # retrieve arg variables + get_item_arg = symbols[n.args[0]] + assert isinstance(get_item_arg, TVar) + + # if the input is dynamic, we accept any index and return + # a dynamic dimension as output + input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) + output_dyn = BinConstraintD(get_item_output, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + + # if the input is a tensor, + # generate a getItem constraint which will be expanded based on the + # tensor dimension. + + c2 = [ + GetItem(i + 1, n.args[1], get_item_output, get_item_arg) + for i in range(MAX_TENSOR_RANK) + ] + + # since the output is a dimension, we make sure it's a natural number + # added as a conjunction to the disjunction of c2 + c3 = BinConstraintD(0, get_item_output, op_leq) + return [Disj([c1, Conj([Disj(c2), c3])])], counter + + # tensor output case + elif isinstance(n.args[1], tuple): + # create and store the new tensor variable + get_item_output, counter = gen_tvar(counter) + symbols[n] = get_item_output + + # retrieve arg variables + if n.args[0] in symbols: + get_item_arg = symbols[n.args[0]] + assert isinstance(get_item_arg, TVar) + + input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) + output_dyn = BinConstraintT(get_item_output, Dyn, op_eq) # type: ignore[assignment] + c1 = Conj([input_dyn, output_dyn]) + + c2 = [ + GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc] + for i in range(MAX_TENSOR_RANK) + ] + else: + # TODO: we should figure out why there is a key-error here. + return [], counter + + return [Disj([c1, *c2])], counter + + else: + raise RuntimeError("Method not yet implemented") + + +@register_inference_rule(operator.gt) +def gt_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + # We make sure this node will not be used again. We do not + # generate a constraint about that node. Only about the operands. + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + gt_tensor, counter = gen_tvar(counter) + symbols[n] = gt_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + gt_constraint = BinConstraintD(e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError("Sort Mismatch") + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + gt_constraint = BinConstraintD(e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + elif isinstance(e1, TVar) and isinstance(e2, int): + # then we made the wrong assumption about the argument being a tensor + # so we should fix the assumption + warnings.warn( + f"Made the wrong assumption for node {n}. Correctness not guaranteed." + ) + + new_e1, counter = gen_dvar(counter) + symbols[n.args[0]] = new_e1 + symbols[n.args[0]] + + gt_constraint = BinConstraintD(new_e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise NotImplementedError("Method not yet implemented") + + else: + raise NotImplementedError("Method not yet implemented") + + +@register_inference_rule(operator.eq) +def eq_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + eq_tensor, counter = gen_tvar(counter) + symbols[n] = eq_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, eq_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + eq_constraint = BinConstraintD(e1, e2, op_eq) + + my_eq, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError("Sort Mismatch") + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + eq_constraint = BinConstraintD(e1, e2, op_eq) + + my_eq, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) + return [equality_constraint], counter + else: + raise NotImplementedError("Method not yet implemented") + else: + raise NotImplementedError("Method not yet implemented") + + +@register_inference_rule(operator.ne) +def neq_inference_rule(n: Node, symbols, constraints, counter): + """ + Translates to inconsistent in gradual types. + To prove inequality, we should prove that + tensors are either different sizes or + disagree on at least one dimension + + This is a WIP (works when the condition + is false. We are working on making this operation work + when the condition is true as well) + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], tuple) + + # implementing for size 3 and 4 + if len(n.args[1]) == 3: + assert isinstance(n.args[1][0], (Node, int)) + assert isinstance(n.args[1][1], (Node, int)) + assert isinstance(n.args[1][2], (Node, int)) + + lhs = symbols[n.args[0]] + + b, counter = gen_tensor_dims(4, counter) + input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq) + + d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] + d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] + d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] + + # dimensions not equal + my_ne, counter = gen_bvar(counter) + neq_1 = BinConstraintD(d1, b[0], op_neq) + neq_2 = BinConstraintD(d2, b[1], op_neq) + neq_3 = BinConstraintD(d3, b[2], op_neq) + + # dimensions inconsistent + dims_inconsistent1 = Conj( + [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1] + ) + dims_inconsistent2 = Conj( + [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2] + ) + dims_inconsistent3 = Conj( + [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3] + ) + + dims_inconsistent = Disj( + [dims_inconsistent1, dims_inconsistent2, dims_inconsistent3] + ) + + # we are covering size 3 and 4 only for now + ne_constraint = Conj([input_is_size3, dims_inconsistent]) + + my_ne, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) + + elif len(n.args[1]) == 4: + assert isinstance(n.args[1][0], (Node, int)) + assert isinstance(n.args[1][1], (Node, int)) + assert isinstance(n.args[1][2], (Node, int)) + assert isinstance(n.args[1][3], (Node, int)) + + lhs = symbols[n.args[0]] + + b1, counter = gen_dvar(counter) + b2, counter = gen_dvar(counter) + b3, counter = gen_dvar(counter) + b4, counter = gen_dvar(counter) + + input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq) + + d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] + d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] + d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] + d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]] + + # dimensions not equal + my_ne, counter = gen_bvar(counter) + neq_1 = BinConstraintD(d1, b1, op_neq) + neq_2 = BinConstraintD(d2, b2, op_neq) + neq_3 = BinConstraintD(d3, b3, op_neq) + neq_4 = BinConstraintD(d4, b4, op_neq) + + # dimensions to inconsistent + dims_inconsistent1 = Conj( + [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1] + ) + dims_inconsistent2 = Conj( + [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2] + ) + dims_inconsistent3 = Conj( + [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3] + ) + dims_inconsistent4 = Conj( + [BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4] + ) + + dims_inconsistent = Disj( + [ + dims_inconsistent1, + dims_inconsistent2, + dims_inconsistent3, + dims_inconsistent4, + ] + ) + + ne_constraint = Conj([input_is_size4, dims_inconsistent]) + + my_ne, counter = gen_bvar(counter) + + equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) + + else: + raise NotImplementedError("Method not yet implemented") + + return [equality_constraint], counter + + +@register_inference_rule(operator.lt) +def lt_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + # We make sure this node will not be used again. We do not + # generate a constraint about that node. Only about the operands. + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + lt_tensor, counter = gen_tvar(counter) + symbols[n] = lt_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, lt_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + lt_constraint = BinConstraintD(e1, e2, op_lt) + + my_lt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError("Sort Mismatch") + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + lt_constraint = BinConstraintD(e1, e2, op_lt) + + my_lt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) + return [equality_constraint], counter + else: + raise NotImplementedError("Method not yet implemented") + + else: + raise NotImplementedError("Method not yet implemented") + + +@register_inference_rule(torch.full) +def full_inference_rule(n: Node, symbols, constraints, counter): + full, counter = gen_tvar(counter) + symbols[n] = full + res = [] + + assert isinstance(n.args[0], Iterable) + for arg in n.args[0]: + dim = arg if isinstance(arg, int) else symbols[arg] + res.append(dim) + c = BinConstraintT(full, TensorType(list(res)), op_eq) # type: ignore[arg-type] + return [c], counter + + +# TODO normalize index +@register_inference_rule(torch.arange) +def arange_inference_rule(n: Node, symbols, constraints, counter): + start = 0 + step = 1 + + if len(n.args) == 1: + end = symbols[n.args[0]] + else: + raise NotImplementedError("Not yet implemented") + + # int((end - start) / step) + d1, counter = gen_dvar(counter) + size_constraint = BinConstraintD( + d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq + ) + arange, counter = gen_tvar(counter) + symbols[n] = arange + + # either the a parameter is a number or it is Dyn + c1 = Disj( + [ + BinConstraintD(end, Dyn, op_eq), + BinConstraintD(start, Dyn, op_eq), + BinConstraintD(step, Dyn, op_eq), + ] + ) + c2 = BinConstraintD(d1, Dyn, op_eq) + both_dyn = Conj([c1, c2]) + + c11 = Conj( + [ + BinConstraintD(end, Dyn, op_neq), + BinConstraintD(start, Dyn, op_neq), + BinConstraintD(step, Dyn, op_neq), + ] + ) + c22 = BinConstraintD(d1, Dyn, op_neq) + both_numbers = Conj([c11, c22, size_constraint]) + + return [ + BinConstraintT(arange, TensorType([d1]), op_eq), + Disj([both_dyn, both_numbers]), + ], counter + + +def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): + # additional vars that don't correspond to expressions + e11, counter = gen_tvar(counter) + e22, counter = gen_tvar(counter) + + # generate constraints + c1 = TGreatestUpperBound(output_var, e11, e22) + c2 = ApplyBroadcasting(e11, e22, e1, e2) + c3 = BinConstraintT(e11, e22, op_consistency) + return [c1, c2, c3], counter + + +@register_inference_rule(operator.mul) +@register_inference_rule(torch.ne) +@register_inference_rule("ne") +@register_inference_rule(torch.add) +@register_inference_rule(operator.add) +def broadcasting_inference_rule(n: Node, symbols, constraints, counter): + op_code = None + if n.target == operator.add or n.target == torch.add: + op_code = op_add + elif n.target == operator.mul: + op_code = op_mul + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(symbols[n.args[0]], TVar) and isinstance( + symbols[n.args[1]], TVar + ): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + e2 = symbols[n.args[1]] + + return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output) + else: + raise NotImplementedError("Method not yet implemented") + + elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)): + if isinstance(symbols[n.args[0]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + return [BinConstraintT(my_output, e1, op_eq)], counter + elif isinstance(symbols[n.args[0]], DVar): + my_output, counter = gen_dvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + + # we will propagate the runtime value here since this is regular addition + c = Conj( + [ + BinConstraintD( + my_output, BinConstraintD(e1, n.args[1], op_code), op_eq + ), + BinConstraintD(0, my_output, op_leq), + ] + ) + return [c], counter + + elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)): + if isinstance(symbols[n.args[1]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e2 = symbols[n.args[1]] + return [BinConstraintT(my_output, e2, op_eq)], counter + elif isinstance(symbols[n.args[1]], DVar): + my_output, counter = gen_dvar(counter) + symbols[n] = my_output + e2 = symbols[n.args[1]] + + # we will propagate the runtime value here since this is regular addition + c = Conj( + [ + BinConstraintD( + my_output, BinConstraintD(e2, n.args[0], op_code), op_eq + ), + BinConstraintD(0, my_output, op_leq), + ] + ) + return [c], counter + + else: + raise NotImplementedError("Method not yet implemented") + + else: + # TODO generate add constraints for scalar addition + raise NotImplementedError("Addition not yet implemented") + + +@register_inference_rule(torch.flatten) +def flatten_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + flattened, counter = gen_tvar(counter) + symbols[n] = flattened + + input = symbols[n.args[0]] + + # set the default start and end dims + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + c1 = BinConstraintT(input, Dyn, op_eq) + c2 = BinConstraintT(flattened, Dyn, op_eq) + both_dyn = Conj([c1, c2]) + + const = [] + for i in range(1, MAX_TENSOR_RANK + 1): + c, counter = generate_flatten_constraints( + start_dim, end_dim, input, flattened, i, counter + ) + const.append(c) + + return [Disj([both_dyn, *const])], counter + + +@register_inference_rule(torch.nn.functional.layer_norm) +def layer_norm_functional(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + assert isinstance(n.args[0], Node) + return gen_layer_norm_constraints(n, n.args[1], symbols, counter) + + +@register_inference_rule(torch.nn.LayerNorm) +def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output shapes should be equal. + Input should be consistent with the normalized_shape + """ + assert isinstance(n.args[0], Node) + return gen_layer_norm_constraints( + n, module_instance.normalized_shape, symbols, counter + ) + + +def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter): + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintT(output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs, counter = gen_tensor_dims(i, counter) + nat_constraints = gen_nat_constraints(new_dims_rhs) + + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims_rhs), op_eq), + BinConstraintT(output, TensorType(new_dims_rhs), op_eq), + ] + + add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) + + nat_constraints + ) + c2.append(c_tensor_i) + return [Disj([c1, Disj(c2)])], counter + + +@register_inference_rule(torch.nn.Dropout) +@register_inference_rule(torch.nn.ReLU) +def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output shapes should be equal. + """ + assert isinstance(n.args[0], Node) + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + assert isinstance(input, TVar) + return [BinConstraintT(input, output, op_eq)], counter + + +@register_inference_rule(torch.nn.Linear) +def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output sizes should be the same except for the last dimension + If the input is Dyn, then so should the output + """ + assert isinstance(n.args[0], Node) + return linear_constraints( + n, module_instance.in_features, module_instance.out_features, symbols, counter + ) + + +@register_inference_rule("dim") +def torch_dim_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + my_dim, counter = gen_dvar(counter) + symbols[n] = my_dim + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintD(my_dim, Dyn, op_eq) + + c1 = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintD(my_dim, i, op_eq), + ] + ) + c1.append(c_tensor_i) + + return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter + + +@register_inference_rule(torch._C._nn.linear) +def torch_linear_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + weight_dims, counter = gen_tensor_dims(2, counter) + equality_constraint = BinConstraintT( + symbols[n.args[1]], TensorType(weight_dims), op_eq + ) + constraints, counter = linear_constraints( + n, weight_dims[1], weight_dims[0], symbols, counter + ) + return [equality_constraint] + constraints, counter + + +def linear_constraints(n: Node, in_features, out_features, symbols, counter): + linear_output, counter = gen_tvar(counter) + symbols[n] = linear_output + linear_input = symbols[n.args[0]] + + input_dyn = BinConstraintT(linear_input, Dyn, op_eq) + output_dyn = BinConstraintT(linear_output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + new_dims_rhs_2, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) + + c_tensor_i = Conj( + [ + BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq), + ] + + add_linear_constraints( + new_dims_rhs_1, new_dims_rhs_2, in_features, out_features + ) + + nat_constraints + ) + c2.append(c_tensor_i) + return [Disj([c1, Disj(c2)])], counter + + +def add_layer_norm_constraints(input_dim, normalized_dim): + """ + The constraints say that the type has te form: [*, 1024, 1024] + while the normalized_dim have the form [1024, 1024] + Args: + input_dim: Input shape of layer norm + normalized_dim: normalized_dim parameter of the module instance + + """ + + # in this case we return false since there's a pattern mismatch + if len(normalized_dim) > len(input_dim): + return [F()] + + else: + constraints = [] + for i, n in zip(reversed(input_dim), reversed(normalized_dim)): + constraints.append(BinConstraintD(i, n, op_consistency)) + return constraints + + +def add_linear_constraints(dims1, dims2, in_features, out_features): + assert len(dims1) == len(dims2) + constraints = [] + for i in range(len(dims1)): + if i == len(dims1) - 1: + constraints.append(BinConstraintD(dims1[i], in_features, op_consistency)) + constraints.append(BinConstraintD(dims2[i], out_features, op_eq)) + else: + constraints.append(BinConstraintD(dims1[i], dims2[i], op_eq)) + + return constraints + + +@register_inference_rule(torch.reshape) +def reshape_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + my_reshape, counter = gen_tvar(counter) + symbols[n] = my_reshape + + src_var = symbols[n.args[0]] + t2 = n.args[1] + t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) # type: ignore[union-attr] + c1 = BinConstraintT(my_reshape, t2_type, op_eq) # type: ignore[union-attr] + c2 = CanReshape(src_var, t2_type) + + return [c1, c2], counter + + +@register_inference_rule(BatchNorm2d) +def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + batchnorm_output, counter = gen_tvar(counter) + symbols[n] = batchnorm_output + batchnorm_input = symbols[n.args[0]] + + # dim vars + d1, counter = gen_dvar(counter) + d2, counter = gen_dvar(counter) + d3, counter = gen_dvar(counter) + d4, counter = gen_dvar(counter) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + c1 = BinConstraintT(batchnorm_input, TensorType([d1, d2, d3, d4]), op_matching) + c2 = BinConstraintT(batchnorm_input, batchnorm_output, op_eq) + return [c1, c2, *nat_constraints], counter + + +@register_inference_rule(torch.nn.AdaptiveAvgPool2d) +def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + avg_pool, counter = gen_tvar(counter) + + symbols[n] = avg_pool + input_var = symbols[n.args[0]] + + # dim vars + d1, counter = gen_dvar(counter) + d2, counter = gen_dvar(counter) + d3, counter = gen_dvar(counter) + d4, counter = gen_dvar(counter) + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + c2 = BinConstraintT( + avg_pool, + TensorType( + [d1, d2, module_instance.output_size[0], module_instance.output_size[1]] + ), + op_eq, + ) + + return [c1, c2, *nat_constraints], counter + + +@register_inference_rule(Conv2d) +def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + my_conv, counter = gen_tvar(counter) + symbols[n] = my_conv + input_var = symbols[n.args[0]] + + # dim vars + [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) + + # c1 = Matching(input_var, TensorType([d1, d2, d3, d4])) + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + + # c2 = DConsistency(module_instance.in_channels, d2) + c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency) + + c3 = CalcConv( + my_conv, + input_var, + module_instance.out_channels, + module_instance.kernel_size, + module_instance.padding, + module_instance.stride, + module_instance.dilation, + [d1, d2, d3, d4], + ) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + return [c1, c2, c3, *nat_constraints], counter + + +@register_inference_rule(torch.nn.MaxPool2d) +def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + maxpool, counter = gen_tvar(counter) + symbols[n] = maxpool + input_var = symbols[n.args[0]] + + # dim vars + [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) + + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + + c2 = CalcMaxPool( + maxpool, + input_var, + module_instance.kernel_size, + module_instance.padding, + module_instance.stride, + module_instance.dilation, + [d1, d2, d3, d4], + ) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + return [c1, c2, *nat_constraints], counter + + +class ConstraintGenerator: + def __init__(self, traced, graph=None): + self.traced = traced # traced or tracer.root + self.traced_params = dict(self.traced.named_parameters()) + self.constraints = [] + self.symbol_dict = {} + self.graph = traced.graph if hasattr(traced, "graph") else graph + + def generate_constraints(self, counter=0): + """ + Iterate through every node and generate constraints + Effect: self.constraints will be populated with the final constraints + """ + graph = self.graph + + all_constraints = [] + + for n in graph.nodes: + (constraints, counter) = self.generate_constraints_node(n, counter) + all_constraints += constraints + + return Conj(all_constraints), counter + + def generate_constraints_node(self, n: Node, counter): + """ + Generate constraints the given node: + Currently supported operations: + - Reshape + - Add + - conv2d + """ + + if n.op == "placeholder": + x, counter = gen_tvar(counter) + self.symbol_dict[n] = x + + my_type = n.type + + if n.type != Dyn and (not isinstance(n.type, TensorType)): + if n.type == torch.nn.parameter.Parameter: + # since we have a parameter, the shape must be static + assert "example_value" in n.meta + my_type = TensorType(n.meta["example_value"].size()) + else: + my_type = Dyn + + c1 = BinConstraintT(my_type, x, op_precision) + c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq) + return [c1, c2], counter + + elif n.op == "call_function": + if n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target]( + n, self.symbol_dict, self.constraints, counter + ) + else: + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) + + elif n.op == "call_module": + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _INFERENCE_RULES: + return _INFERENCE_RULES[type(module_instance)]( + n, module_instance, self.symbol_dict, self.constraints, counter + ) + else: + raise RuntimeError( + f"No inference rule registered for class {type(module_instance)}!" + ) + + elif n.op == "call_method": + if n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target]( + n, self.symbol_dict, self.constraints, counter + ) + else: + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) + + elif n.op == "get_attr": + t = self.traced_params.get(n.target, None) + + if isinstance(t, torch.Tensor): + if len(t.shape) > 0: + res = list(t.shape) + attr_type = TensorType(res) + output, counter = gen_tvar(counter) + self.symbol_dict[n] = output + return [BinConstraintT(output, attr_type, op_eq)], counter + else: + # scalar? + return [], counter + else: + return [], counter + + elif n.op == "output": + return [], counter + + else: + raise NotImplementedError(f"Method {n.op} not yet implemented") diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..11ebff0102093c035576eda29e71bfd9ccb87ce6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py @@ -0,0 +1,1322 @@ +# mypy: ignore-errors +import copy +import itertools +from typing import Callable + +from torch.fx.experimental.migrate_gradual_types.constraint import ( + ApplyBroadcasting, + BinConstraintD, + CalcConv, + CalcMaxPool, + CalcProduct, + CanReshape, + Conj, + Constraint, + DGreatestUpperBound, + Disj, + DVar, + F, + GetItem, + GetItemTensor, + IndexSelect, + Prod, + T, + TGreatestUpperBound, + Transpose, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.constraint_generator import ( + BinConstraintT, + MAX_TENSOR_RANK, +) +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_consistency, + op_div, + op_eq, + op_leq, + op_matching, + op_mod, + op_mul, + op_neq, + op_precision, + op_sub, +) +from torch.fx.experimental.migrate_gradual_types.util import ( + gen_dvar, + gen_nat_constraints, + gen_tensor_dims, +) +from torch.fx.tensor_type import Dyn, TensorType + + +_TRANSFORMATION_RULES: dict[Constraint, Callable] = {} + + +def register_transformation_rule(call_target): + def register(fn): + if call_target in _TRANSFORMATION_RULES: + raise RuntimeError( + f"Transformation rule already registered for {call_target}!" + ) + _TRANSFORMATION_RULES[call_target] = fn + return fn + + return register + + +def valid_index(index, dims): + """ + Given a list of dimensions, checks if an index is valid in the list + """ + try: + dims[index] + return T() + except IndexError: + return F() + + +@register_transformation_rule(Transpose) +def transform_transpose(constraint, counter): + """ + Similar to a sequence of two index-selects + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + is_valid_index1 = valid_index(constraint.index1, dims) + is_valid_index2 = valid_index(constraint.index2, dims) + new_dims = copy.deepcopy(dims) + nat_constraints = gen_nat_constraints(dims) + + if is_valid_index1 == T() and is_valid_index2 == T(): + new_dims[constraint.index1] = dims[constraint.index2] + new_dims[constraint.index2] = dims[constraint.index1] + + transformed_constraint = Conj( + [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index1, + is_valid_index2, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq), + ] + ) + return transformed_constraint, counter + + +@register_transformation_rule(IndexSelect) +def transform_index_select(constraint, counter): + """ + The constraints consider the given tensor size, checks if the index is valid + and if so, generates a constraint for replacing the input dimension + with the required dimension + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + is_valid_index = valid_index(constraint.index, dims) + nat_constraints = gen_nat_constraints(dims) + + # if the index is valid then replace the input dimension with the new dimension + # otherwise the dimension will not be replaced and the clause will contain False + if is_valid_index == T(): + new_dims = copy.deepcopy(dims) + new_dims[constraint.index] = constraint.dim_replace + + transformed_constraint = Conj( + [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq), + ] + ) + + # print(constraints) + return transformed_constraint, counter + + +@register_transformation_rule(GetItem) +def transform_get_item(constraint, counter): + """ + generate an equality of the form: + t = [a1, ..., an] + then generate constraints that check if the given index is valid + given this particular tensor size. + If the index is valid, generate a constraint to get the item + Note that we already handled the Dyn input case in the previous + step. + Args: + constraint: GetItem which assumes we are getting an item from a tensor (not Dyn) + counter: variable tracking + Returns: simplified constraints for GetItem + + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + nat_constraints = gen_nat_constraints(dims) + + is_valid_index = valid_index(constraint.index, dims) + + all_constraints = [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index, + ] + + # if the index is valid, we generate a constraint for getting an item + # otherwise this clause will have been UNSAT due to the wrong index + if is_valid_index == T(): + all_constraints.append( + BinConstraintD(constraint.res, dims[constraint.index], op_eq) + ) + + return Conj(all_constraints), counter + + +def valid_index_tensor(index, dims): + """ + if the slice instances exceed the length of the dimensions + then this is a type error so we return False + """ + slice_count = 0 + for s in index: + if isinstance(s, slice): + slice_count += 1 + if slice_count > len(dims): + return F() + else: + return T() + + +@register_transformation_rule(GetItemTensor) +def transform_get_item_tensor(constraint, counter): + """ + When the index is a tuple, then the output will be a tensor + TODO: we have to check if this is the case for all HF models + + The cases we are covering here are a tuple with one of: + - slice with default argument + - None + + None appends 1 to the input tensor dimensions + so each occurrence of 'None' increases the rank by 1 + + slice with default arguments does not change the rank + """ + assert isinstance(constraint.index_tuple, tuple) + + # generate a result tensor of the expected size + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + nat_constraints = gen_nat_constraints(dims) + + # generate a place-holder list of the right rank + # where "slice" does not contribute to the rank and "None" does + none_c = constraint.index_tuple.count(None) + resulting_tensor_dims = (none_c + len(dims)) * [None] + + dim_index = 0 + for i in range(len(constraint.index_tuple)): + # append 1 to the right location of the resulting tensor + if constraint.index_tuple[i] is None: + resulting_tensor_dims[i] = 1 + + elif constraint.index_tuple[i] == slice(None, None, None): + pass + + else: + raise NotImplementedError("Method not yet implemented") + + # append the remaining dimensions to the right location + dim_index = 0 + for i in range(len(resulting_tensor_dims)): + if resulting_tensor_dims[i] is None: + resulting_tensor_dims[i] = dims[dim_index] + dim_index += 1 + + # check if the index is valid + is_valid_index = valid_index_tensor(constraint.index_tuple, dims) + + # check if the resulting tensor is within bounds + if len(resulting_tensor_dims) > 4: + return F(), counter + + else: + constraints = [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq), + *nat_constraints, + is_valid_index, + ] + return Conj(constraints), counter + + +@register_transformation_rule(BinConstraintT) +def generate_binconstraint_t(constraint, counter): + """ + Transform binary constraints for tensors + """ + + # precision constraints + if constraint.op == op_precision: + if constraint.lhs == Dyn: + return T(), counter + elif isinstance(constraint.lhs, TensorType): + is_fully_static = all(d != Dyn for d in constraint.lhs.__args__) + if is_fully_static: + return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter + else: + new_dims = [] + + for _ in range(len(constraint.lhs.__args__)): + dim, counter = gen_dvar(counter) + new_dims.append(dim) + + new_dim_constraints = ( + [ + BinConstraintD(old_dim, new_dim, op_precision) + for new_dim, old_dim in zip(new_dims, constraint.lhs.__args__) + ] + + [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + + [BinConstraintD(1, new_dim, op_leq) for new_dim in new_dims] + ) + return Conj(new_dim_constraints), counter + + # matching + elif constraint.op == op_matching: + assert isinstance(constraint.rhs, TensorType) + d1 = constraint.rhs.__args__[0] + d2 = constraint.rhs.__args__[1] + d3 = constraint.rhs.__args__[2] + d4 = constraint.rhs.__args__[3] + + conj = [ + BinConstraintT(constraint.lhs, Dyn, op_eq), + BinConstraintD(d1, Dyn, op_eq), + BinConstraintD(d2, Dyn, op_eq), + BinConstraintD(d3, Dyn, op_eq), + BinConstraintD(d4, Dyn, op_eq), + ] + return ( + Disj( + [ + Conj(conj), + BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq), + ] + ), + counter, + ) + + elif constraint.op == op_consistency: + c_dyn = Disj( + [ + BinConstraintT(constraint.lhs, Dyn, op_eq), + BinConstraintT(constraint.rhs, Dyn, op_eq), + ] + ) + ( + ( + c_tensor_1, + c_tensor_2, + c_tensor_3, + c_tensor_4, + ), + counter, + ) = gen_consistency_constraints(constraint, counter) + + return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter + + elif constraint.op == op_leq: + assert isinstance(constraint.rhs, int) + disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)] + for i in range(1, constraint.rhs + 1): + dims = [] + for _ in range(1, i + 1): + dim_var, counter = gen_dvar(counter) + dims.append(dim_var) + disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq)) + return Disj(disj), counter + else: + return constraint, counter + + +@register_transformation_rule(BinConstraintD) +def generate_binconstraint_d(constraint, counter): + """ + Transform binary constraints for dimensions + """ + if constraint.op == op_precision: + if isinstance(constraint.lhs, int): + return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter + elif constraint.lhs == Dyn: + return T(), counter + + elif constraint.op == op_consistency: + return ( + Disj( + [ + BinConstraintD(constraint.lhs, constraint.rhs, op_eq), + BinConstraintD(constraint.rhs, Dyn, op_eq), + BinConstraintD(constraint.lhs, Dyn, op_eq), + ] + ), + counter, + ) + + else: + return constraint, counter + + +@register_transformation_rule(Conj) +def generate_conj(constraint, counter): + """ + Transform conjunctions + """ + new = [] + for c in constraint.conjucts: + new_c, counter = transform_constraint(c, counter) + new.append(new_c) + return Conj(new), counter + + +@register_transformation_rule(Disj) +def generate_disj(constraint, counter): + """ + Transform disjunctions + """ + new = [] + for c in constraint.disjuncts: + new_c, counter = transform_constraint(c, counter) + new.append(new_c) + return Disj(new), counter + + +@register_transformation_rule(TGreatestUpperBound) +def generate_gub(constraint, counter): + """ + Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound + on dimensions + """ + c1 = Conj( + [ + Disj( + [ + BinConstraintT(constraint.rhs1, Dyn, op_eq), + BinConstraintT(constraint.rhs2, Dyn, op_eq), + ] + ), + BinConstraintT(constraint.res, Dyn, op_eq), + ] + ) + + [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter) + + return Disj([c1, c2, c3, c4, c5]), counter + + +@register_transformation_rule(DGreatestUpperBound) +def generate_d_gub(constraint, counter): + """ + Transform greatest upper bound for dimensions into equality constraints + """ + c1 = Conj( + [ + BinConstraintD(constraint.rhs1, Dyn, op_eq), + BinConstraintD(constraint.res, constraint.rhs2, op_eq), + ] + ) + c2 = Conj( + [ + BinConstraintD(constraint.rhs2, Dyn, op_eq), + BinConstraintD(constraint.res, constraint.rhs1, op_eq), + ] + ) + c3 = Conj( + [ + BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), + BinConstraintD(constraint.res, constraint.rhs1, op_eq), + ] + ) + return Disj([c1, c2, c3]), counter + + +@register_transformation_rule(CalcConv) +def generate_calc_conv(constraint, counter): + d, counter = gen_tensor_dims(4, counter) + conv_result = TensorType([d[0], d[1], d[2], d[3]]) + + # the convolution result is a tensor of size 4 + c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq) + + # the second dimension of the output is equal to the output channels + c2 = Conj( + [ + BinConstraintD(d[1], constraint.c_out, op_eq), + BinConstraintD(d[1], Dyn, op_neq), + ] + ) + + # the input corresponds to the output in the first dimension of the convolution + c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) + + c4, c5 = calc_last_two_dims(constraint, d) + + leq_constraints = Conj( + [ + BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq), + ] + ) + + return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter + + +@register_transformation_rule(CalcMaxPool) +def generate_calc_maxpool(constraint, counter): + """ + Transform maxpool constraints + """ + d, counter = gen_tensor_dims(4, counter) + maxpool_result = TensorType([d[0], d[1], d[2], d[3]]) + + # the maxpool result is a tensor of size 4 + c1 = BinConstraintT(constraint.maxpool_result, maxpool_result, op_eq) + + # the input corresponds to the output in the first and second dimension of maxpool + c2 = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq) + c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) + c4, c5 = calc_last_two_dims(constraint, d) + + leq_constraints = Conj( + [ + BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq), + ] + ) + + return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter + + +@register_transformation_rule(CalcProduct) +def generate_calc_product(constraint, counter): + """ + Transform flatten constraints + """ + start = constraint.start + end = constraint.end + dims = constraint.dims_to_flatten + flattened = constraint.flattened + n = len(constraint.dims_to_flatten) + + # this will be evaluated right here + boundary_check = 0 <= start and start < end and end <= n + + c_boundary = T() if boundary_check else F() + + lhs = dims[0:start] + rhs = dims[end:] + mid = dims[start:end] + + all_possibilities = generate_all_int_dyn_dim_possibilities(mid) + + all_constraints = [] + + for p in all_possibilities: + p = list(p) + # this tells us there is a dynamic variable + contains_dyn = not all(constraint.op == op_neq for constraint in p) + if contains_dyn: + mid_var = [Dyn] + total_constraints = lhs + mid_var + rhs + if len(total_constraints) > 4: + all_constraints.append(F()) + else: + all_constraints.append( + Conj( + [ + BinConstraintT( + flattened, TensorType(lhs + mid_var + rhs), op_eq + ) + ] + + p + ) + ) + else: + new_var, counter = gen_dvar(counter) + mid_eq_prod = Conj( + [ + BinConstraintD(new_var, Prod(mid), op_eq), + BinConstraintD(new_var, Dyn, op_neq), + ] + ) + mid_var = [new_var] + total_constraints = lhs + mid_var + rhs + if len(total_constraints) > 4: + all_constraints.append(F()) + else: + all_constraints.append( + Conj( + [ + BinConstraintT( + flattened, TensorType(lhs + mid_var + rhs), op_eq + ), + mid_eq_prod, + ] + + p + ) + ) + + return Conj([Disj(all_constraints), c_boundary]), counter + + +@register_transformation_rule(CanReshape) +def generate_reshape(constraint, counter): + """ + Transform reshape constraints + """ + d, counter = gen_tensor_dims(4, counter) + + d1 = d[0] + d2 = d[1] + d3 = d[2] + d4 = d[3] + + target = constraint.target.__args__ + + is_fully_static = all(d != Dyn for d in target) + + # dynamic tensor + c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq) + c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq) + c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq) + c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq) + c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq) + + d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq) + d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq) + + d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq) + d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq) + + d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq) + d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq) + + d4_eq_dyn = BinConstraintD(d3, Dyn, op_eq) + d4_neq_dyn = BinConstraintD(d3, Dyn, op_neq) + + nat_d1 = BinConstraintD(0, d1, op_leq) + nat_d2 = BinConstraintD(0, d2, op_leq) + nat_d3 = BinConstraintD(0, d3, op_leq) + nat_d4 = BinConstraintD(0, d4, op_leq) + + if is_fully_static: + # size 1 tensor + c3_tensor1 = Disj( + [d1_eq_dyn, (Conj([d1_neq_dyn, BinConstraintD(d1, Prod(target), op_eq)]))] + ) + all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) + + # size 2 tensor + all_tensor_2 = Conj( + [c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)] + ) + + # size 3 tensor + all_tensor_3 = Conj( + [c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)] + ) + + # size 4 tensor + all_tensor_4 = Conj( + [c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)] + ) + + return ( + Conj( + [ + Disj( + [c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4] + ), + nat_d1, + nat_d2, + nat_d3, + nat_d4, + ] + ), + counter, + ) + + # then there must be exactly one occurrence of dyn + else: + new_target = [n for n in target if n != Dyn] + + # tensor 1 + c3_tensor1 = Disj( + [d1_eq_dyn, (Conj([d1_neq_dyn, is_dim_div_by_target(new_target, d1)]))] + ) + all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) + + # tensor 2 + c21 = Disj([d1_eq_dyn, d2_eq_dyn]) + c22 = Conj( + [d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))] + ) + all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])]) + + # tensor 3 + c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn]) + c32 = Conj( + [ + d1_neq_dyn, + d2_neq_dyn, + d3_neq_dyn, + is_dim_div_by_target(new_target, Prod([d1, d2, d3])), + ] + ) + all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])]) + + # tensor 4 + c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn]) + c42 = Conj( + [ + d1_neq_dyn, + d2_neq_dyn, + d3_neq_dyn, + d4_neq_dyn, + is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4])), + ] + ) + all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])]) + + return ( + Conj( + [ + Disj( + [c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4] + ), + nat_d1, + nat_d2, + nat_d3, + nat_d4, + ] + ), + counter, + ) + + +@register_transformation_rule(ApplyBroadcasting) +def generate_broadcasting(constraint, counter): + """ + Transform broadcasting constraints + """ + e11, e12 = constraint.res1, constraint.res2 + e1, e2 = constraint.input1, constraint.input2 + + e1_dyn = BinConstraintT(e1, Dyn, op_eq) + e2_dyn = BinConstraintT(e2, Dyn, op_eq) + + # Introduce dimensions + e1_equal_e11 = BinConstraintT(e1, e11, op_eq) + e2_equal_e12 = BinConstraintT(e2, e12, op_eq) + + # dyn possibility + e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12]) + e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12]) + + # tensor possibility + # generate dimensions to create tensors of size 1 + final_tensor_1_constraint, _, _, nat_dims_1, counter = gen_broadcasting_constraints( + e1, e2, e11, e12, 1, counter + ) + + # generate dimensions to create tensors of size 2 + ( + final_tensor_2_constraint_no_padding, + final_tensor_2_constraint_padding_arg1, + final_tensor_2_constraint_padding_arg2, + nat_dims_2, + counter, + ) = gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter) + + # generate dimensions to create tensors of size 3 + ( + final_tensor_3_constraint_no_padding, + final_tensor_3_constraint_padding_arg1, + final_tensor_3_constraint_padding_arg2, + nat_dims_3, + counter, + ) = gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter) + + # generate dimensions to create tensors of size 4 + ( + final_tensor_4_constraint_no_padding, + final_tensor_4_constraint_padding_arg1, + final_tensor_4_constraint_padding_arg2, + nat_dims_4, + counter, + ) = gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter) + + final_result = Disj( + [ + e1_dyn_constraint, + e2_dyn_constraint, + final_tensor_1_constraint, + final_tensor_2_constraint_no_padding, + final_tensor_2_constraint_padding_arg1, + final_tensor_2_constraint_padding_arg2, + final_tensor_3_constraint_no_padding, + final_tensor_3_constraint_padding_arg1, + final_tensor_3_constraint_padding_arg2, + final_tensor_4_constraint_no_padding, + final_tensor_4_constraint_padding_arg1, + final_tensor_4_constraint_padding_arg2, + ] + ) + + return ( + Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), + counter, + ) + + +def transform_constraint(constraint: Constraint, counter: int): + """ + Transforms a constraint into a simpler constraint. + Ex: precision and consistency are transformed to equality + Args: + constraint: constraint to be transformed + counter: for variable tracking + + Returns: Constraint + + """ + if type(constraint) in _TRANSFORMATION_RULES: + return _TRANSFORMATION_RULES[type(constraint)](constraint, counter) + + else: + return constraint, counter + + +def calc_last_two_dims(constraint, d: list[DVar]): + """ + Generates constraints for the last two dimensions of a convolution or a maxpool output + Args: + constraint: CalcConv or CalcMaxPool + d: The list of output dimensions + + Returns: Constraints for calculating the last two dimensions of the output + + """ + + assert isinstance(constraint, (CalcConv, CalcMaxPool)) + + b3 = constraint.matching_constraint[2] + b4 = constraint.matching_constraint[3] + + b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)]) + b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)]) + + d3_not_dyn = Conj( + [BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)] + ) + d4_not_dyn = Conj( + [BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)] + ) + + # transform parameters into tuples incase they are not already + padding = ( + (constraint.padding, constraint.padding) + if isinstance(constraint.padding, int) + else constraint.padding + ) + kernel = ( + (constraint.kernel, constraint.kernel) + if isinstance(constraint.kernel, int) + else constraint.kernel + ) + stride = ( + (constraint.stride, constraint.stride) + if isinstance(constraint.stride, int) + else constraint.stride + ) + dilation = ( + (constraint.dilation, constraint.dilation) + if isinstance(constraint.dilation, int) + else constraint.dilation + ) + + f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add) + f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul) + f3 = BinConstraintD( + BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div + ) + f4 = BinConstraintD(f3, 1, op_add) + + c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])]) + + f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add) + f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul) + f33 = BinConstraintD( + BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div + ) + f44 = BinConstraintD(f33, 1, op_add) + + c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])]) + + return c4, c5 + + +def generate_all_int_dyn_dim_possibilities(my_list: list[DVar]): + """ + Generate all possibilities of being equal or not equal to dyn for my_list + Args: + my_list: List of tensor dimensions + + Returns: A list of a list of constraints. Each list of constraints corresponds to + one possibility about the values of the dimension variables + """ + # generate all possibilities of being equal or not equal to dyn for my_list + eq_possibilities = [ + BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list)) + ] + neq_possibilities = [ + BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list)) + ] + + d_possibilities = [list(i) for i in zip(eq_possibilities, neq_possibilities)] + all_possibilities = list(itertools.product(*d_possibilities)) + return all_possibilities + + +def is_target_div_by_dim(target: list[int], dim: list[DVar]): + """ + Generate constraints to check if the target dimensions are divisible by the input dimensions + Args: + target: Target dimensions + dim: Input dimensions + + Returns: Constraints to check divisibility + + """ + return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq) + + +def is_dim_div_by_target(target: list[int], dim: list[DVar]): + """ + Generate constraints to check if the input dimensions is divisible by the target dimensions + Args: + target: Target dimensions + dim: Input dimensions + + Returns: Constraints to check divisibility + + """ + return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq) + + +def gen_all_reshape_possibilities(list_of_dims, target): + """ + Consider all possibilities what the input dimensions could be (number or dynamic) + Then generate the appropriate constraints using multiplication or mod depending on the possibility + The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn + for the input. Target is fixed because at most one dimension could be dyn. + We have different cases for this. + + Args: + list_of_dims: The input list of dimensions + target: The tensor we want to reshape to + + Returns: A disjunction of transformed reshape constraints + + """ + all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims) + + all_constraints = [] + + for p in all_possibilities: + to_multiply = [] + + p = list(p) + + for constraint in p: + assert isinstance(constraint, BinConstraintD) + if constraint.op == op_neq: + to_multiply.append(constraint.lhs) + + if not to_multiply: + all_constraints.append(Conj(p)) + + elif len(to_multiply) < len(list_of_dims): + all_constraints.append( + Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))]) + ) + else: + all_constraints.append( + Conj(p + [BinConstraintD(Prod(list_of_dims), Prod(target), op_eq)]) + ) + + return Disj(all_constraints) + + +def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False): + """ + Apply broadcasting to the 'index' dimension of tensor_input1. + Args: + tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1 + tensor_input2: represents the second input + res1: broadcasted result 1 + res2: broadcasted result 2 + index: the index to broadcast + padding: If padding was used, then tensor_input1[index] does not exist + + Returns: + + """ + if tensor_input1[index] is None: + assert padding + + if not padding: + # then the inputs are the same length so they all have dimensions at "index" + return Conj( + [ + BinConstraintD(tensor_input1[index], 1, op_eq), + BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq), + ] + ) + + else: + # we don't set the input dimension to 1, since it doesn't exist. + return Conj( + [ + BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq), + ] + ) + + +def apply_padding( + e1_var: TVar, + e11: BinConstraintT, + e2: BinConstraintT, + e12: BinConstraintT, + d2: list[DVar], + d11: list[DVar], + d12: list[DVar], + counter: int, +): + """ + We are considering the possibility where one input has less dimensions than + another input, so we apply padding to the broadcasted results + + Args: + e1_var: Variable representing the first input where padding will be + e11: constraint of the form e11 = Tensortype[d1, ..., dn] + e2: constraint of the form e2 = Tensortype[d1, ..., dn] + e12: constraint of the form e11 = Tensortype[d1, ..., dn] + d2: Tensor variables for the second input + d11: Tensor variables for the broadcasted first input + d12: Tensor variables for the broadcasted second input + counter: variable tracking + + Returns: A new constraint whose goal is to apply padding to the broadcasted result + + """ + + res = [] + + # pad the shorter input with None so we can pass it to the broadcasting helper function + for i in range(1, len(d2)): + d1, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12) + + e1 = BinConstraintT(e1_var, TensorType(d1), op_eq) + + simulate_padding = [None] * (len(d2) - i) + + assert len(simulate_padding + d1) == len(d2) + + # for every padding size, we also consider broadcasting + broadcast_padding = [ + broadcast_dim(simulate_padding, d2, d11, d12, j, True) + for j in range(len(d2) - i) + ] + + # we consider the possibilities for broadcasting for every dimension. Since we already + # padded d1, we do not consider it while broadcasting + all_broadcasting_possibilities = ( + generate_all_broadcasting_possibilities_no_padding( + d1, d2[(len(d2) - i) :], d11[(len(d2) - i) :], d12[(len(d2) - i) :] + ) + ) + # combine all constraints into a conjunction + c = Conj( + [ + e1, + e11, + e2, + e12, + *broadcast_padding, + all_broadcasting_possibilities, + *nat_constraints, + ] + ) + res.append(c) + + return Disj(res), counter + + +def no_broadcast_dim_with_index( + d1: list[DVar], d2: list[DVar], d3: list[DVar], d4: list[DVar], i: int +): + """ + Args: + d1: input 1 + d2: input 2 + d3: simulated broadcasting for input 1 + d4: simulated broadcasting for input 2 + i: the rank of the resulting tensor addition + + Returns: Constraints for when no broadcasting occurs + """ + return Conj( + [ + Disj( + [ + Conj( + [ + BinConstraintD(d1[i], 1, op_eq), + BinConstraintD(d2[i], 1, op_eq), + ] + ), + Conj( + [ + BinConstraintD(d1[i], 1, op_neq), + BinConstraintD(d2[i], 1, op_neq), + ] + ), + ] + ), + BinConstraintD(d1[i], d3[i], op_eq), + BinConstraintD(d2[i], d4[i], op_eq), + ] + ) + + +def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int): + """ + Generate lists of DVar to represent tensor dimensions + Args: + num_tensors: the required number of tensors + dim_size: the number of dimensions for each tensor + counter: variable tracking + + Returns: A list of a list of tensor dimensions + + """ + res = [] + + for _ in range(num_tensors): + dims, counter = gen_tensor_dims(dim_size, counter) + res.append(dims) + + return res, counter + + +def create_equality_constraints_for_broadcasting( + e1: TVar, + e2: TVar, + e11: TVar, + e12: TVar, + d1: list[DVar], + d2: list[DVar], + d11: list[DVar], + d12: list[DVar], +): + """ + Create equality constraints for when no broadcasting occurs + Args: + e1: Input 1 + e2: Input 2 + e11: Broadcasted input 1 + e12: Broadcasted input 2 + d1: Variables that store dimensions for e1 + d2: Variables that store dimensions for e2 + d11: Variables that store dimensions for e11 + d12: Variables that store dimensions for e22 + + Returns: Four equality constraints + + """ + + e1_tensor = BinConstraintT(e1, TensorType(d1), op_eq) + e11_tensor = BinConstraintT(e11, TensorType(d11), op_eq) + e2_tensor = BinConstraintT(e2, TensorType(d2), op_eq) + e12_tensor = BinConstraintT(e12, TensorType(d12), op_eq) + return [e1_tensor, e11_tensor, e2_tensor, e12_tensor] + + +def gen_consistency_constraints(constraint: Constraint, counter: int): + """ + Args: + constraint: Consistency constraint on tensors + counter: for variable tracking + + Returns: Equality and consistency constraints on dimensions + + """ + + all_constraints = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + new_dims_rhs_2, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) + + c_tensor_i = Conj( + [ + BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq), + ] + + [ + BinConstraintD(d1, d2, op_consistency) + for d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2) + ] + + nat_constraints + ) + + all_constraints.append(c_tensor_i) + + return all_constraints, counter + + +def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int): + """ + Args: + constraint: Greatest upper bound on tensors + counter: variable tracking + + Returns: A set of equality constraints and DGreatestUpperBound constraints + + """ + + all_constraints = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + c = [] + dims1, counter = gen_tensor_dims(i, counter) + c1tensor = TensorType(dims1) + + dims2, counter = gen_tensor_dims(i, counter) + c2tensor = TensorType(dims2) + + dims3, counter = gen_tensor_dims(i, counter) + c3tensor = TensorType(dims3) + + c += [ + BinConstraintT(constraint.rhs1, c1tensor, op_eq), + BinConstraintT(constraint.rhs2, c2tensor, op_eq), + BinConstraintT(constraint.res, c3tensor, op_eq), + ] + gen_nat_constraints(dims1 + dims2 + dims3) + + assert ( + len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__) + ) + for i in range(len(c3tensor.__args__)): + c.append( + DGreatestUpperBound( + c3tensor.__args__[i], c1tensor.__args__[i], c2tensor.__args__[i] + ) + ) + + all_constraints.append(Conj(c)) + return all_constraints, counter + + +def generate_all_broadcasting_possibilities_no_padding( + d1: list[DVar], d2: list[DVar], d11: list[DVar], d12: list[DVar] +): + """ + Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension. + We look at all combinations for all dimensions in d1 and d2 + Args: + d1: input1 dimensions + d2: input2 dimensions + d11: broadcasted input1 dimensions + d12: broadcasted input2 dimensions + + Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions + + """ + + size = len(d1) + + res2 = [] + + for i in range(size): + t1 = broadcast_dim(d1, d2, d11, d12, i) + t2 = broadcast_dim(d2, d1, d12, d11, i) + t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i) + + res2.append(Disj([t1, t2, t3])) + + return Conj(res2) + + +def gen_broadcasting_constraints( + e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int +): + """ + Simulates broadcasting on e1 and e2 and returns the results + respectively in e11 and e12. Because of gradual types, + e1 and e2 may not be equal. Similarly, e11 and e12 may not + be equal. e11 and e12 should be guaranteed to be consistent + as they represent the shapes of the tensors to be added after + broadcasting. + Args: + e1: TVar representing the type of input 1 + e2: TVar representing the type of input 2 + e11: TVar representing the representing broadcasted input 1 + e12: TVar representing the representing broadcasted input 2 + i: The rank of the resulting type of addition + counter: for variable tracking + + Returns: Simplified broadcasting constraints + + """ + dims, counter = gen_lists_of_dims(4, i, counter) + [d1, d2, d3, d4] = dims + nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims))) + + initialize_tensors_constraints = create_equality_constraints_for_broadcasting( + e1, e2, e11, e12, d1, d2, d3, d4 + ) + + [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints + + # without padding, broadcast all possibilities for tensors of size i + final_tensor_constraint_no_padding = Conj( + [ + *initialize_tensors_constraints, + generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4), + ] + ) + + # with padding, broadcast all possibilities for tensors of size i + final_tensor_constraint_padding_arg1, counter = apply_padding( + e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter + ) + + final_tensor_constraint_padding_arg2, counter = apply_padding( + e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter + ) + + return ( + final_tensor_constraint_no_padding, + final_tensor_constraint_padding_arg1, + final_tensor_constraint_padding_arg2, + nat_dims_i, + counter, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py new file mode 100644 index 0000000000000000000000000000000000000000..267100c8545c8b2310299337ecf64211f633f6ce --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py @@ -0,0 +1,14 @@ +op_add = "+" +op_sub = "-" +op_mul = "*" +op_div = "/" +op_eq = "=" +op_neq = "!=" +op_imp = "=>" +op_matching = "\u22b3" # (contains) +op_consistency = "~" +op_precision = "\u2291" # (square image of or equal to) +op_leq = "\u2264" # less-than or equal to +op_lt = "<" +op_gt = ">" +op_mod = "%" diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py new file mode 100644 index 0000000000000000000000000000000000000000..d1f9f33965e07551c651fa560a80c5e263dd5b85 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py @@ -0,0 +1,446 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.migrate_gradual_types.constraint import ( + BinConstraintD, + BinConstraintT, + BVar, + Conj, + Disj, + DVar, + F, + is_algebraic_expression, + is_bool_expr, + is_dim, + Prod, + T, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.constraint_generator import ( + ConstraintGenerator, +) +from torch.fx.experimental.migrate_gradual_types.constraint_transformation import ( + transform_constraint, +) +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_div, + op_eq, + op_gt, + op_leq, + op_lt, + op_mod, + op_mul, + op_neq, + op_sub, +) +from torch.fx.tensor_type import Dyn, TensorType + + +try: + import z3 # type: ignore[import] + + from torch.fx.experimental.migrate_gradual_types.z3_types import ( + D, + tensor_type, + z3_dyn, + ) + + HAS_Z3 = True + + def transform_to_z3(constraint, counter, dimension_dict): + if isinstance(constraint, Conj): + conjuncts = [] + for c in constraint.conjucts: + new_c, counter = transform_to_z3(c, counter, dimension_dict) + conjuncts.append(new_c) + return z3.And(conjuncts), counter + + elif isinstance(constraint, Disj): + disjuncts = [] + for c in constraint.disjuncts: + new_c, counter = transform_to_z3(c, counter, dimension_dict) + disjuncts.append(new_c) + return z3.Or(disjuncts), counter + + elif isinstance(constraint, T): + return True, counter + + elif isinstance(constraint, F): + return False, counter + + elif isinstance(constraint, BinConstraintT): + if constraint.op == op_eq: + lhs, counter = transform_var(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_var(constraint.rhs, counter, dimension_dict) + return (lhs == rhs), counter + + else: + raise NotImplementedError("Method not yet implemented") + + elif isinstance(constraint, BinConstraintD): + if constraint.op == op_eq: + if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs): + transformed_rhs, counter = transform_to_z3( + constraint.rhs, counter, dimension_dict + ) + transformed_lhs = z3.Bool(constraint.lhs.c) + return transformed_lhs == transformed_rhs, counter + + elif is_dim(constraint.lhs) and is_dim(constraint.rhs): + # with dimension transformations we consider the encoding + lhs, counter = transform_dimension( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_dimension( + constraint.rhs, counter, dimension_dict + ) + return lhs == rhs, counter + + else: + # then we have an algebraic expression which means that we disregard the + # first element of the encoding + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) + return lhs == rhs, counter + + # The assumption here is that the LHS and RHS must be dimensions + elif constraint.op == op_neq: + assert is_dim(constraint.lhs) + assert is_dim(constraint.rhs) + lhs, counter = transform_dimension( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_dimension( + constraint.rhs, counter, dimension_dict + ) + if constraint.rhs == Dyn or constraint.lhs == Dyn: + if constraint.rhs == Dyn: + return lhs.arg(0) == 1, counter + elif constraint.lhs == Dyn: + return rhs.arg(0) == 1, counter + + # if one of the instances is a number + elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int): + if isinstance(constraint.lhs, int): + return ( + z3.Or( + [ + rhs.arg(0) == 0, + z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]), + ] + ), + counter, + ) + + elif isinstance(constraint.rhs, int): + return ( + z3.Or( + [ + lhs.arg(0) == 0, + z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]), + ] + ), + counter, + ) + + else: + return ( + z3.Or( + [ + z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]), + z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]), + z3.And( + [ + lhs.arg(0) != 0, + rhs.arg(0) != 0, + lhs.arg(1) != rhs.arg(1), + ] + ), + ] + ), + counter, + ) + + elif constraint.op == op_leq: + # if the dimensions are not dyn, this will come into effect + # there would have been another constraint specifying if a given dimension + # is dyn or not + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) + return lhs <= rhs, counter + + elif constraint.op == op_gt: + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) + return lhs > rhs, counter + + elif constraint.op == op_lt: + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) + return lhs < rhs, counter + + else: + raise NotImplementedError("operation not yet implemented") + + else: + raise NotImplementedError("Operation not yet implemented") + + def transform_var(tensor, counter, dimension_dict): + """ + Transforms tensor variables to a format understood by z3 + Args: + tensor: Tensor variable or a tensor type potentially with variable dimensions + Returns: Transformed variable to a z3 format + + """ + if isinstance(tensor, TensorType): + res = [] + for t in tensor.__args__: + transformed, counter = transform_dimension(t, counter, dimension_dict) + res.append(transformed) + + assert len(res) <= 4 + if len(tensor.__args__) == 1: + return tensor_type.tensor1(res[0]), counter + elif len(tensor.__args__) == 2: + return tensor_type.tensor2(res[0], res[1]), counter + elif len(tensor.__args__) == 3: + return tensor_type.tensor3(res[0], res[1], res[2]), counter + elif len(tensor.__args__) == 4: + return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter + + elif tensor == Dyn: + return z3_dyn, counter + + elif isinstance(tensor, TVar): + return z3.Const(tensor.tvar, tensor_type), counter + + def transform_dimension(dimension, counter, dimension_dict): + """ + Takes a dimension variable or a number and transforms it to a tuple + according to our scheme + Args: + dimension: The dimension to be transformed + counter: variable tracking + + Returns: tuple and the current counter + + """ + if dimension == Dyn: + counter += 1 + return D(0, z3.Int(counter)), counter + elif isinstance(dimension, int): + return D(1, dimension), counter + elif isinstance(dimension, DVar): + if dimension.c in dimension_dict: + return ( + D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), + counter, + ) + else: + counter += 1 + dimension_dict[dimension.c] = counter + return D(z3.Int(counter), z3.Int(dimension.c)), counter + + def transform_algebraic_expression(expr, counter, dimension_dict): + """ + Transforms an algebraic expression to z3 format + Args: + expr: An expression is either a dimension variable or an algebraic-expression + + + Returns: the transformed expression + + """ + assert is_algebraic_expression(expr) or is_dim(expr) + + if is_dim(expr): + transformed, counter = transform_dimension(expr, counter, dimension_dict) + return transformed.arg(1), counter + + elif isinstance(expr, Prod): + dims = [] + for dim in expr.products: + assert is_dim(dim) + d, counter = transform_dimension(dim, counter, dimension_dict) + dims.append(d.arg(1)) + return z3.Product(dims), counter + + elif is_algebraic_expression(expr): + lhs, counter = transform_algebraic_expression( + expr.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + expr.rhs, counter, dimension_dict + ) + + if expr.op == op_sub: + c = lhs - rhs + + elif expr.op == op_add: + c = lhs + rhs + + elif expr.op == op_div: + c = lhs / rhs + + elif expr.op == op_mul: + c = lhs * rhs + + elif expr.op == op_mod: + c = lhs % rhs + + else: + raise NotImplementedError("operation not yet implemented") + + return c, counter + + else: + raise RuntimeError + + def transform_all_constraints(traced, counter=0): + """ + Given a trace, generates constraints and transforms them to z3 format + + """ + dimension_dict = {} # type: ignore[var-annotated] + + generator = ConstraintGenerator(traced) + new_constraints, counter = generator.generate_constraints(counter) + + # print(new_constraints.conjucts[0]) + # print(*new_constraints.conjucts, sep='\n') + + # transform precision, matching, consistency till obtaining a fixed point + new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) + # print(new_constraints) + # print(new_constraints.conjucts) + # new_constraints.conjucts = new_constraints.conjucts[:-1] + # print(*new_constraints.conjucts, sep='\n') + + transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) + # print(transformed) + return transformed + + def iterate_till_fixed_point(constraints, counter): + """ + Transform constraints till reaching a fixed point + """ + old_c = None + while old_c != constraints: + old_c = constraints + constraints, counter = transform_constraint(constraints, counter) + return constraints, counter + + def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0): + """ + Takes a node and a graph and generates two sets of constraints. + One set constraints the node's constraints and another set + constraints the negation of the node's constraints + Args: + tracer_root: the root for getting the module instances + graph: the graph so far in the tracing process + node: node that represents a conditional + counter: variable tracking + + Returns: Two sets of constraints. One with a conjunction with the + the conditional constraint and the other with a conjunction with + its negation. + + """ + dimension_dict = {} # type: ignore[var-annotated] + + generator = ConstraintGenerator(tracer_root, graph) + new_constraints, counter = generator.generate_constraints(counter) + + condition_constraint = new_constraints.conjucts[-1] + + # we know the constraint is a conjunction where the last constraint is about the conditional + # so remove the last constraint + new_constraints.conjucts = new_constraints.conjucts[:-1] + + # transform precision, matching, consistency till obtaining a fixed point + new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) + + # since the function returns a list of one element, we get the first element + # we are only interested in the RHS in this case because the LHS just stores + # the result + + # we make sure the constraint is of the form: + # c = b where b is a boolean expression + # and we consider b (constraint.rhs) for transformation + assert isinstance(condition_constraint.lhs, BVar) + assert is_bool_expr(condition_constraint.rhs) + condition_constraint_rhs = condition_constraint.rhs + + # transform the condition constraint + condition_constraint_rhs, counter = iterate_till_fixed_point( + condition_constraint_rhs, counter + ) + + transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) + + transformed_condition_constraint, counter = transform_to_z3( + condition_constraint_rhs, counter, dimension_dict + ) + + negation_transformed_condition_constraint = z3.Not( + transformed_condition_constraint + ) + + return z3.And([transformed, transformed_condition_constraint]), z3.And( + [transformed, negation_transformed_condition_constraint] + ) + + def evaluate_conditional_with_constraints( + tracer_root, graph, node, counter=0, user_constraints=None + ): + """ + Given an IR and a node representing a conditional, evaluate the conditional + and its negation + Args: + tracer_root: Tracer root for module instances + node: The node to be evaluated + + Returns: the results of evaluating the condition and the negation with + the rest of the constraints + + """ + + ( + transformed_positive, + transformed_negative, + ) = transform_all_constraints_trace_time(tracer_root, graph, node, counter) + + s = z3.Solver() + s.add(transformed_positive) + if user_constraints is not None: + s.add(user_constraints) + condition = s.check() + + s = z3.Solver() + s.add(transformed_negative) + if user_constraints is not None: + s.add(user_constraints) + negation = s.check() + return condition, negation + +except ImportError: + HAS_Z3 = False diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/util.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/util.py new file mode 100644 index 0000000000000000000000000000000000000000..bd40d2a463f5e78e3548df224ecd15e22813a3c6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/util.py @@ -0,0 +1,59 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.migrate_gradual_types.constraint import ( + BinConstraintD, + BVar, + DVar, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.operation import op_leq + + +def gen_tvar(curr): + """ + Generate a tensor variable + :param curr: The current counter + :return: a tensor variable and the updated counter + """ + curr += 1 + return TVar(curr), curr + + +def gen_dvar(curr): + """ + Generate a dimension variable + :param curr: the current counter + :return: a dimension variable and an updated counter + """ + curr += 1 + return DVar(curr), curr + + +def gen_bvar(curr): + """ + Generate a boolean variable + :param curr: the current counter + :return: a boolean variable and an updated counter + """ + curr += 1 + return BVar(curr), curr + + +def gen_tensor_dims(n, curr): + """ + Generate a list of tensor dimensions + :param n: the number of dimensions + :param curr: the current counter + :return: a list of dimension variables and an updated counter + """ + dims = [] + for _ in range(n): + dvar, curr = gen_dvar(curr) + dims.append(dvar) + return dims, curr + + +def gen_nat_constraints(list_of_dims): + """ + Generate natural number constraints for dimensions + """ + return [BinConstraintD(0, d, op_leq) for d in list_of_dims] diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py new file mode 100644 index 0000000000000000000000000000000000000000..939f4865ab7d982289303093db2024eda6603521 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py @@ -0,0 +1,30 @@ +try: + import z3 # type: ignore[import] + + HAS_Z3 = True + # dynamic type + dyn = z3.DeclareSort("Dyn") + dyn_type = z3.Const("dyn", dyn) + + # dimension + dim = z3.Datatype("dim") + dim.declare("dim", ("0", z3.IntSort()), ("1", z3.IntSort())) + dim = dim.create() + + # tensors + tensor_type = z3.Datatype("TensorType") + tensor_type.declare("Dyn", ("dyn", dyn)) + tensor_type.declare("tensor1", ("0", dim)) + tensor_type.declare("tensor2", ("0", dim), ("1", dim)) + tensor_type.declare("tensor3", ("0", dim), ("1", dim), ("2", dim)) + tensor_type.declare("tensor4", ("0", dim), ("1", dim), ("2", dim), ("3", dim)) + tensor_type = tensor_type.create() + + # create dimension + D = dim.dim + + z3_dyn = tensor_type.Dyn(dyn_type) + + +except ImportError: + HAS_Z3 = False diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/normalize.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..73cce6017bf1b5cd944ebff1f26781ef26fb6638 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/normalize.py @@ -0,0 +1,163 @@ +# mypy: allow-untyped-defs +import operator +from typing import Any, Callable, Optional + +import torch +import torch.fx +import torch.fx as fx +from torch.fx import Proxy, Transformer +from torch.fx.node import Argument, map_aggregate, Node, Target +from torch.fx.operator_schemas import ( + create_type_hint, + normalize_function, + normalize_module, +) + +from .schema_type_annotation import AnnotateTypesWithSchema + + +class NormalizeArgs(Transformer): + """ + Normalize arguments to Python targets. This means that + `args/kwargs` will be matched up to the module/functional's + signature and rewritten to exclusively kwargs in positional order + if `normalize_to_only_use_kwargs` is true. Also populates default + values. Does not support positional-only parameters or varargs + parameters (*args, **kwargs). + + If the nodes have 'type' metadata, it will use it to disambiguate + overloads. Otherwise, it will throw an error. + + Example usage: + m = torchvision.models.resnet18() + traced = torch.fx.symbolic_trace(m) + traced = NormalizeArgs(traced).transform() + """ + + def __init__( + self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True + ): + super().__init__(module) + self.node_map: dict[Proxy, Node] = {} + self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs + + def run_node(self, n: Node) -> Any: + args, kwargs = self.fetch_args_kwargs_from_env(n) + + def get_type(arg): + if isinstance(arg, fx.Node): + return n.meta["type"] if "type" in n.meta else None + return type(arg) + + arg_types = map_aggregate(n.args, get_type) + assert isinstance(arg_types, tuple) + arg_types = tuple([create_type_hint(i) for i in arg_types]) + kwarg_types = {k: get_type(v) for k, v in kwargs.items()} + if n.op == "call_function": + out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types) + else: + out = super().run_node(n) + if n.op != "output": + self.node_map[out] = n + out.node.meta = n.meta + out.node.type = n.type + return out + + def call_function( + self, + target: Target, + args: tuple[Argument, ...], + kwargs: dict[str, Any], + arg_types: Optional[tuple[Any, ...]] = None, + kwarg_types: Optional[dict[str, Any]] = None, + ): + assert callable(target) + new_args_and_kwargs = normalize_function( + target, + args, # type: ignore[arg-type] + kwargs, + arg_types, # type: ignore[arg-type] + kwarg_types, + self.normalize_to_only_use_kwargs, + ) + if new_args_and_kwargs: + new_args, new_kwargs = new_args_and_kwargs + return self.tracer.create_proxy( + "call_function", target, new_args, new_kwargs + ) + else: + return super().call_function(target, args, kwargs) + + def call_module( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ): + assert isinstance(target, str) + new_args_and_kwargs = normalize_module( + self.module, + target, + args, # type: ignore[arg-type] + kwargs, + self.normalize_to_only_use_kwargs, + ) + if new_args_and_kwargs: + new_args, new_kwargs = new_args_and_kwargs + return super().call_module(target, new_args, new_kwargs) + else: + return super().call_module(target, args, kwargs) + + +class NormalizeOperators(AnnotateTypesWithSchema): + """ + Normalize callsites that are different ways of "spelling" the same + invocation into a single, canonical call. Currently supports: + + 1. Normalize operators (e.g. operator.add) to the `torch` ops they + ultimately invoke (e.g. torch.add) when it is possible to statically + reason that + + Example usage: + + m = torchvision.models.resnet18() + + traced = torch.fx.symbolic_trace(m) + + traced = NormalizeOperators(traced).transform() + """ + + binary_magic_method_remap: dict[ + Callable[[Any, Any], Any], Callable[[Any, Any], Any] + ] = { + torch.add: operator.add, + torch.mul: operator.mul, + torch.sub: operator.sub, + torch.div: operator.truediv, + torch.floor_divide: operator.floordiv, + torch.remainder: operator.mod, + torch.eq: operator.eq, + torch.ne: operator.ne, + torch.lt: operator.lt, + torch.le: operator.le, + torch.gt: operator.gt, + torch.ge: operator.ge, + } + + def call_function( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ): + # Normalize operators according to the magic methods implemented on tensors here: + # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950 + + assert callable(target) + + if target in self.binary_magic_method_remap: + if len(args) != 2: + return super().call_function(target, args, kwargs) + lhs, rhs = args + + return super().call_function( + target=self.binary_magic_method_remap[target], + args=(lhs, rhs), + kwargs={}, + ) + + return super().call_function(target, args, kwargs) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/optimization.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..3e406b57a96d571411ead68c404464c3bc10d63c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/optimization.py @@ -0,0 +1,486 @@ +# mypy: allow-untyped-defs +import copy +import logging +import operator +import time +from collections import defaultdict +from collections.abc import Iterable +from enum import Enum +from typing import Any, cast, Optional + +import torch +import torch.fx as fx +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.mkldnn as th_mkldnn +from torch.fx.node import Argument, Target +from torch.fx.passes.shape_prop import ShapeProp +from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_linear_bn_eval + + +__all__ = [ + "matches_module_pattern", + "replace_node_module", + "fuse", + "remove_dropout", + "extract_subgraph", + "modules_to_mkldnn", + "reset_modules", + "MklSubgraph", + "gen_mkl_autotuner", + "use_mkl_length", + "UnionFind", + "optimize_for_inference", +] + + +def _parent_name(target: str) -> tuple[str, str]: + """ + Splits a qualname into parent path and last atom. + For example, `foo.bar.baz` -> (`foo.bar`, `baz`) + """ + *parent, name = target.rsplit(".", 1) + return parent[0] if parent else "", name + + +# Works for length 2 patterns with 2 modules +def matches_module_pattern( + pattern: Iterable[type], node: fx.Node, modules: dict[str, Any] +): + if len(node.args) == 0: + return False + nodes: tuple[Any, fx.Node] = (node.args[0], node) + for expected_type, current_node in zip(pattern, nodes): + if not isinstance(current_node, fx.Node): + return False + if current_node.op != "call_module": + return False + if not isinstance(current_node.target, str): + return False + if current_node.target not in modules: + return False + if type(modules[current_node.target]) is not expected_type: + return False + return True + + +def replace_node_module( + node: fx.Node, modules: dict[str, Any], new_module: torch.nn.Module +): + assert isinstance(node.target, str) + parent_name, name = _parent_name(node.target) + modules[node.target] = new_module + setattr(modules[parent_name], name, new_module) + + +def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module: + """ + Fuses convolution/BN and linear/BN layers for inference purposes. + Will deepcopy your model by default, but can modify the model inplace as well. + """ + patterns = [ + (nn.Conv1d, nn.BatchNorm1d), + (nn.Conv2d, nn.BatchNorm2d), + (nn.Conv3d, nn.BatchNorm3d), + (nn.Linear, nn.BatchNorm1d), + ] + if not inplace: + model = copy.deepcopy(model) + if not no_trace or not isinstance(model, torch.fx.GraphModule): + fx_model = fx.symbolic_trace(model) + else: + fx_model = model + modules = dict(fx_model.named_modules()) + new_graph = copy.deepcopy(fx_model.graph) + + for pattern in patterns: + for node in new_graph.nodes: + if matches_module_pattern(pattern, node, modules): + if len(node.args[0].users) > 1: + # Output of conv/linear is used by other nodes + continue + first_layer = modules[node.args[0].target] + bn = modules[node.target] + if not bn.track_running_stats: + continue + if pattern[0] in [nn.Conv1d, nn.Conv2d, nn.Conv3d]: + fused_layer = fuse_conv_bn_eval(first_layer, bn) + else: # nn.Linear + fused_layer = fuse_linear_bn_eval(first_layer, bn) + replace_node_module(node.args[0], modules, fused_layer) + node.replace_all_uses_with(node.args[0]) + new_graph.erase_node(node) + return fx.GraphModule(fx_model, new_graph) + + +def remove_dropout(model: nn.Module) -> nn.Module: + """ + Removes all dropout layers from the module. + """ + fx_model = fx.symbolic_trace(model) + + class DropoutRemover(torch.fx.Transformer): + def call_module( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + if isinstance(self.submodules[target], nn.Dropout): + assert len(args) == 1 + return args[0] + else: + return super().call_module(target, args, kwargs) + + return DropoutRemover(fx_model).transform() + + +def extract_subgraph( + orig_module: nn.Module, + nodes: list[fx.Node], + inputs: list[fx.Node], + outputs: list[fx.Node], +): + """ + Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph. + """ + new_graph = fx.Graph() + env: dict[fx.Node, fx.Node] = {} + for input in inputs: + new_node = new_graph.placeholder(input.name) + env[input] = new_node + for node in nodes: + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + new_graph.output([env[output] for output in outputs]) + new_graph.lint() + return fx.GraphModule(orig_module, new_graph) + + +mkldnn_supported = [ + nn.Conv2d, + nn.Linear, + nn.BatchNorm2d, + nn.ReLU, + nn.MaxPool2d, + nn.AvgPool2d, + nn.AdaptiveAvgPool2d, + torch.relu, + torch.transpose, + torch.sigmoid, + F.relu, + F.avg_pool2d, + F.adaptive_avg_pool2d, +] +# These are operators that may not be convertible into MKLDNN ops (e.g. the +# args are scalar values). Thus, we only include them in the subgraph if their +# arguments are already in MKLDNN. +# TODO: Determine whether this can be removed after type inference. +mkldnn_supported_unknown = [operator.add, operator.mul] +mkldnn_map = { + nn.Conv2d: th_mkldnn.MkldnnConv2d, + nn.Linear: th_mkldnn.MkldnnLinear, + nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a), +} + + +def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]): + """ + For each node, if it's a module that can be preconverted into MKLDNN, + then we do so and create a mapping to allow us to convert from the MKLDNN + version of the module to the original. + """ + old_modules: dict[nn.Module, nn.Module] = {} + for node in nodes: + if node.op == "call_module": + assert isinstance(node.target, str) + cur_module = modules[node.target] + if type(cur_module) in mkldnn_map: + new_module = mkldnn_map[type(cur_module)](cur_module, torch.float) + assert isinstance(new_module, nn.Module) + old_modules[new_module] = copy.deepcopy(cur_module) + replace_node_module(node, modules, new_module) + return old_modules + + +def reset_modules( + nodes: list[fx.Node], + modules: dict[str, nn.Module], + old_modules: dict[nn.Module, nn.Module], +): + """ + Maps each module that's been changed with `modules_to_mkldnn` back to its + original. + """ + for node in nodes: + if node.op == "call_module": + assert isinstance(node.target, str) + cur_module = modules[node.target] + if cur_module in old_modules: + replace_node_module(node, modules, old_modules[cur_module]) + + +class MklSubgraph: + def __init__(self, fx_graph: fx.Graph): + self.fx_graph = fx_graph + self.nodes: list[fx.Node] = [] + self.start_nodes: list[fx.Node] = [] + self.end_nodes: list[fx.Node] = [] + + +def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): + """ + This generates a heuristic that can be passed into `optimize_for_inference` that + determines whether a subgraph should be run in MKL by running it with the example_inputs. + + Example usage: + heuristic = gen_mkl_autotuner(example_inputs, iters=10) + fast_model = optimization.optimize_for_inference(model, heuristic) + """ + fx_model = None + old_modules = None + + def use_mkl_heuristic(graph: MklSubgraph) -> bool: + nonlocal fx_model, old_modules + input_nodes = graph.start_nodes + if fx_model is None: + fx_model = graph.fx_graph.owning_module + old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined] + ShapeProp(fx_model).propagate(example_inputs) + sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined] + output_args = cast(list[fx.Node], [node.args[0] for node in graph.end_nodes]) + submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args) + + def benchmark(f): + for _ in range(warmup): + f() + begin = time.time() + for _ in range(iters): + f() + return time.time() - begin + + mkl_time = benchmark( + lambda: [ + i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs]) + ] + ) + + reset_modules( + submodule.graph.nodes, dict(submodule.named_modules()), old_modules + ) + no_mkl_time = benchmark(lambda: submodule(*sample_inputs)) + return mkl_time < no_mkl_time + + return use_mkl_heuristic + + +def use_mkl_length(graph: MklSubgraph) -> bool: + """ + This is a heuristic that can be passed into `optimize_for_inference` that + determines whether a subgraph should be run in MKL by checking if there + are more than 2 nodes in it + """ + return len(graph.nodes) > 2 + + +class UnionFind: + def __init__(self, n): + self.parent: list[Optional[int]] = [None] * n + self.size: list[int] = [0] * n + + def make_set(self, v: int): + self.parent[v] = v + self.size[v] = 1 + + def find(self, v: int) -> int: + par = self.parent[v] + if v == par: + return v + assert par is not None + self.parent[v] = self.find(par) + return cast(int, self.parent[v]) + + def join(self, a: int, b: int): + a, b = self.find(a), self.find(b) + if a == b: + return a + if self.size[a] < self.size[b]: + a, b = b, a + self.parent[b] = a + self.size[a] += self.size[b] + + +def optimize_for_inference( + model: torch.nn.Module, + pass_config: Optional[dict[str, Any]] = None, + tracer: type[fx.Tracer] = fx.Tracer, +) -> torch.nn.Module: + """ + Performs a set of optimization passes to optimize a model for the + purposes of inference. Specifically, the passes that are run are: + 1. Conv/BN fusion + 2. Dropout removal + 3. MKL layout optimizations + + The third optimization takes a function `use_mkl_heuristic` that's used + to determine whether a subgraph should be explicitly run in MKL layout. + + Note: As FX does not currently handle aliasing, this pass currently + assumes nothing aliases. If that isn't true, use at your own risk. + """ + default_pass_config = { + "conv_bn_fuse": True, + "remove_dropout": True, + "mkldnn_layout_optimize": {"heuristic": use_mkl_length}, + } + if pass_config is None: + pass_config = {} + default_pass_config.update(pass_config) + + if default_pass_config["conv_bn_fuse"]: + model = fuse(model) + if default_pass_config["remove_dropout"]: + model = remove_dropout(model) + if default_pass_config["mkldnn_layout_optimize"] is False: + return model + if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict): + raise RuntimeError("mkldnn_layout_optimize config is not a dict") + if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]: + raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config") + use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"] + + cur_tracer = tracer() + fx_graph = cur_tracer.trace(copy.deepcopy(model)) + fx.GraphModule(cur_tracer.root, fx_graph) + modules: dict[str, nn.Module] = dict(model.named_modules()) + + class MklSupport(Enum): + NO = 1 + YES = 2 + UNKNOWN = 3 + + # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node. + # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node. + # However, if it's in `mkldnn_supported_unknown`, then we only treat it as + # a MKLDNN node if its inputs are MKLDNN nodes. + for node in list(fx_graph.nodes): + supports_mkldnn = MklSupport.NO + if node.op == "call_module": + cur_module = modules[node.target] + if type(cur_module) in mkldnn_supported: + supports_mkldnn = MklSupport.YES + sample_parameter = next(cur_module.parameters(), None) + if sample_parameter is not None: + assert sample_parameter.dtype == torch.float, ( + "this pass is only for torch.float modules" + ) + assert sample_parameter.device == torch.device("cpu"), ( + "this pass is only for CPU modules" + ) + elif node.op == "call_function": + if node.target in mkldnn_supported: + supports_mkldnn = MklSupport.YES + elif node.target in mkldnn_supported_unknown: + supports_mkldnn = MklSupport.UNKNOWN + + if supports_mkldnn != MklSupport.NO: + if supports_mkldnn == MklSupport.UNKNOWN: + if not any(arg.target == "to_dense" for arg in node.args): + continue + with fx_graph.inserting_before(node): + mkldnn_args = fx.map_arg( + node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,)) + ) + + node.args = cast(tuple[fx.node.Argument], mkldnn_args) + + with fx_graph.inserting_after(node): + dense_x = fx_graph.create_node("call_method", "to_dense", (node,)) + node.replace_all_uses_with(dense_x) + dense_x.args = (node,) + + # Does pre-conversion of all modules into MKLDNN (when possible) + old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules) + fx_graph.old_modules = old_modules # type: ignore[attr-defined] + + # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b + for node in fx_graph.nodes: + if node.op == "call_method" and node.target == "to_dense": + prv_node = node.args[0] + users = list(node.users) + for user in users: + if user.op == "call_method" and user.target == "to_mkldnn": + user.replace_all_uses_with(prv_node) + fx_graph.erase_node(user) + if len(node.users) == 0: + fx_graph.erase_node(node) + + num_nodes = len(fx_graph.nodes) + uf = UnionFind(num_nodes) + + def get_color(n): + if hasattr(n, "color"): # Current node is part of a MKL subgraph + return uf.find(n.color) + if hasattr(n, "start_color"): # Current node is input to MKL subgraph + return uf.find(n.start_color) + return None + + # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists + # of input nodes (which are only `to_mkldnn` calls), output nodes + # (`to_dense` calls), and intermediate nodes, which are run entirely on + # MKLDNN layout tensors. + # + # Specifically, this code does a flood fill on a directed acyclic graph + # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes). + # If every node only had one input, this would be sufficient. However, in + # the case that a node has multiple inputs coming from different start + # nodes (i.e. colors), we need to join these 2 colors into 1. That's done + # using a Disjoint Set Union. + for cur_idx, node in enumerate(fx_graph.nodes): + if node.op == "call_method" and node.target == "to_mkldnn": + node.start_color = cur_idx + uf.make_set(cur_idx) + elif node.op == "call_method" and node.target == "to_dense": + assert get_color(node.args[0]) is not None + node.end_color = get_color(node.args[0]) + else: + cur_colors = [ + get_color(i) + for i in node.all_input_nodes + if isinstance(i, fx.Node) + if get_color(i) is not None + ] + + if len(cur_colors) == 0: + continue + assert not any(i is None for i in cur_colors) + cur_colors = sorted(cur_colors) + node.color = cur_colors[0] + for other_color in cur_colors[1:]: + uf.join(cur_colors[0], other_color) + + mkldnn_graphs: dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) + for node in fx_graph.nodes: + if hasattr(node, "color"): + mkldnn_graphs[uf.find(node.color)].nodes.append(node) + if hasattr(node, "start_color"): + mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node) + if hasattr(node, "end_color"): + mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node) + + # Now that we have all the subgraphs, we need to decide which MKLDNN + # subgraphs we actually want to keep in MKLDNN. + for graph in mkldnn_graphs.values(): + if not use_mkl_heuristic(graph): + for node in graph.start_nodes + graph.end_nodes: + prv = node.args[0] + node.replace_all_uses_with(prv) # type: ignore[arg-type] + fx_graph.erase_node(node) + reset_modules(graph.nodes, modules, old_modules) + + mkldnn_conversions = 0 + for node in fx_graph.nodes: + if node.target == "to_mkldnn" or node.target == "to_dense": + mkldnn_conversions += 1 + + logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions) + fx_graph.lint() + result = fx.GraphModule(model, fx_graph) + return result diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/partitioner_utils.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/partitioner_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3658dd1a9ce96aff26adbc5f47818e9e57e13d35 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/partitioner_utils.py @@ -0,0 +1,317 @@ +# mypy: allow-untyped-defs +from enum import Enum +from typing import NamedTuple + +from torch.fx.node import map_arg, Node + + +class Partition: + """Partition class contains all the information about an individual partition. + It also provides necessary methods for manipulation the partition. + """ + + def __init__(self, partition_id: int) -> None: + self.nodes: set[Node] = set() + self.partition_id = partition_id + self.parents: set[Partition] = set() + self.children: set[Partition] = set() + self.bfs_level: int = -1 + self.used_mem_bytes: int = 0 + self.logical_device_ids: list[int] = [] + + def __str__(self): + return str(self.partition_id) + + def recalculate_mem_size(self): + self.used_mem_bytes = 0 + for node in self.nodes: + self.used_mem_bytes += get_extra_size_of(node, self.nodes) + + def add_node(self, node): + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Add current node's input nodes if they are placeholder or constants + for n in input_nodes: + if n.op in {"placeholder", "get_attr"}: + self.nodes.add(n) + self.nodes.add(node) + self.recalculate_mem_size() + + def remove_node(self, node): + # Remove a node only if the node is in the partition + if node in self.nodes: + self.nodes.remove(node) + # Collect the node's input nodes + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Check if an input node is a placeholder or get_attr, + # and this input node is not used by some other nodes in this partition, + # the remove this input node + for input_node in input_nodes: + if all( + n not in self.nodes for n in input_node.users + ) and input_node.op in {"placeholder", "get_attr"}: + self.nodes.remove(input_node) + self.recalculate_mem_size() + + +class Device(NamedTuple): + name: str + available_mem_bytes: int + logical_id: int + + +class NodeLatency(NamedTuple): + # Latency due to the memory bandwidth + mem_latency_sec: float + # Latency due to the computation + computer_latency_sec: float + + +class PartitionLatency(NamedTuple): + # Sum of all nodes' memory latency on the critical path + mem_latency_sec: float + # Sum of all nodes' compute latency on the critical path + computer_latency_sec: float + # Latency of the critical path + overall_latency_sec: float + + +class PartitionMode(Enum): + size_based = 0 + sparse_nn = 1 + cost_aware = 2 + kl_based = 3 + aot_based = 4 + + +class PartitionerConfig(NamedTuple): + devices: list[Device] + mode: PartitionMode = PartitionMode.size_based + transfer_rate_bytes_per_sec: float = 0.0 + node_to_latency_mapping: dict[Node, NodeLatency] = {} + node_to_partition_mapping: dict[Node, int] = {} + partition_to_logical_device_mapping: dict[int, list[int]] = {} + # Saturate host by replicating partitions to the remaining idle devices. + saturate_host: bool = False + + +def get_extra_size_of(node: Node, nodes: set[Node]) -> int: + """Given a node and a set of nodes, + this function return the extra size that needed + if this node is included in this set. + """ + # Find all its input nodes + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Calculate total size of related nodes + total_size_of_input_nodes = 0 + for n in input_nodes: + # Make sure this node hasn't been in this set yet + if n not in nodes: + size_bytes = getattr(n, "size_bytes", None) + if size_bytes: + total_size_of_input_nodes += size_bytes.output_size + else: + raise RuntimeError("node has no size_bytes attr") + # Don't forget the op node itself + size_bytes = getattr(node, "size_bytes", None) + if size_bytes: + total_size_of_input_nodes += size_bytes.total_size + else: + raise RuntimeError("node has no size_bytes attr") + return total_size_of_input_nodes + + +def get_latency_of_one_partition( + partition: Partition, node_to_latency_mapping: dict[Node, NodeLatency] +) -> PartitionLatency: + """Given a partition and its nodes' latency, return a PartitionLatency for this partition""" + + def get_top_nodes(partition: Partition) -> list[Node]: + """Given a partition, return a list of nodes on the top bfs level""" + top_nodes: list[Node] = [] + for node in partition.nodes: + # Skip placeholder and get_attr nodes + if node.op in {"placeholder", "get_attr"}: + continue + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # If a node has no input nodes in this partition, + # or its input nodes in this partition are placeholders and get_attrs + # this node is on the top bfs level in this partition + if not any( + n in partition.nodes and n.op not in {"placeholder", "get_attr"} + for n in input_nodes + ): + top_nodes.append(node) + return top_nodes + + def dfs_helper(node: Node, partition_latency) -> PartitionLatency: + """Given a top node of a partition, this function returns + the latency of the critical path in the partition + """ + node_latency = node_to_latency_mapping[node] + # Calculate the current overall latency of the partition + overall_latency_sec = partition_latency.overall_latency_sec + max( + node_latency.computer_latency_sec, node_latency.mem_latency_sec + ) + # Update the mem latency of this path + mem_latency_sec = ( + partition_latency.mem_latency_sec + node_latency.mem_latency_sec + ) + # Update the compute latency of this path + computer_latency_sec = ( + partition_latency.computer_latency_sec + node_latency.computer_latency_sec + ) + # Get all users of this node that are in this partition + users = set(node.users).intersection(partition.nodes) + if users: + max_latency = PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ) + for n in users: + # Get new partition latency recursively + new_partition_latency = dfs_helper( + n, + PartitionLatency( + mem_latency_sec, computer_latency_sec, overall_latency_sec + ), + ) + if ( + new_partition_latency.overall_latency_sec + > max_latency.overall_latency_sec + ): + max_latency = new_partition_latency + return max_latency + # If there is no user, the node is at bottom of the partition + return PartitionLatency( + mem_latency_sec, computer_latency_sec, overall_latency_sec + ) + + # Main part starts + # Get all top level nodes of this partition + top_nodes = get_top_nodes(partition) + critical_path_latency = PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ) + # Go through all top nodes and find the largest latency (critical pass latency) + for node in top_nodes: + partition_latency = dfs_helper( + node, + PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ), + ) + if ( + partition_latency.overall_latency_sec + > critical_path_latency.overall_latency_sec + ): + critical_path_latency = partition_latency + return critical_path_latency + + +def get_partition_to_latency_mapping( + partitions: list[Partition], node_to_latency_mapping: dict[Node, NodeLatency] +) -> dict[Partition, PartitionLatency]: + """Given all the partitions and node_to_latency_mapping dictionary, + return a mapping dictionary of each partition to its overall latency + """ + partition_to_latency_mapping: dict[Partition, PartitionLatency] = {} + # Go through each partition and get its latency + for partition in partitions: + partition_latency = get_latency_of_one_partition( + partition, node_to_latency_mapping + ) + partition_to_latency_mapping[partition] = partition_latency + return partition_to_latency_mapping + + +def get_comm_latency_between( + parent_partition: Partition, + child_partition: Partition, + transfer_rate_bytes_per_sec: float, +): + """Given two partitions (parent and child), + calculate the communication latency between the two. + """ + # If two partitions are on the same device, the comm latency is 0. + if ( + parent_partition.logical_device_ids != [] + and child_partition.logical_device_ids != [] + and parent_partition.logical_device_ids == child_partition.logical_device_ids + ): + return 0.0 + # Keep tracking the communication size between parent and child + comm_size = 0 + # Keep tracking all the counted node + visited_nodes = set() + # Go through all nodes in the child partition + # If a node has input nodes from the parent partition, + # the output size of those input nodes will be counted + # and added to comm_size + for node in child_partition.nodes: + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + for n in input_nodes: + if n in parent_partition.nodes and n not in visited_nodes: + size_bytes = getattr(n, "size_bytes", None) + if size_bytes is not None: + comm_size += size_bytes.output_size + visited_nodes.add(n) + return comm_size / transfer_rate_bytes_per_sec + + +def get_latency_of_partitioned_graph( + partitions: list[Partition], + partition_to_latency_mapping: dict[Partition, PartitionLatency], + transfer_rate_bytes_per_sec: float, +): + """Given all partitions in a graph, find the critical path among all partitions + and return its latency as the latency of the whole graph + """ + + def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float: + """This function helps to recursively get the latency of a path of partitions""" + # Update latency by adding current partition's latency + latency_so_far_sec += partition_to_latency_mapping[ + partition + ].overall_latency_sec + + if partition.children: + max_latency_sec = 0.0 + for child in partition.children: + # Calculate latency between + comm_latency_sec = get_comm_latency_between( + partition, child, transfer_rate_bytes_per_sec + ) + new_latency_sec = dfs_helper( + child, latency_so_far_sec + comm_latency_sec + ) + if new_latency_sec > max_latency_sec: + max_latency_sec = new_latency_sec + return max_latency_sec + return latency_so_far_sec + + def get_top_partitions(partitions: list[Partition]) -> list[Partition]: + """This function is to return all the partitions without parents + as the starting points of all the paths + """ + # If a partition has no parents, then it is a top partition + top_partitions = [ + partition for partition in partitions if len(partition.parents) == 0 + ] + return top_partitions + + top_partitions = get_top_partitions(partitions) + critical_path_latency_sec = 0.0 + for partition in top_partitions: + latency_sec = dfs_helper(partition, 0.0) + if latency_sec > critical_path_latency_sec: + critical_path_latency_sec = latency_sec + return critical_path_latency_sec diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..c8d8e58d952ed0b2d7de99d33d196ea18ea0291a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py @@ -0,0 +1,2434 @@ +# mypy: allow-untyped-decorators +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import functools +import inspect +import logging +import operator +import traceback +import typing +import typing_extensions +import weakref +from collections import defaultdict, OrderedDict +from collections.abc import Generator, Mapping, Sequence +from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Optional, + overload, + Protocol, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import Concatenate, ParamSpec, Self, TypeVarTuple, Unpack +from weakref import WeakKeyDictionary + +import torch +import torch._ops +import torch.fx as fx +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree +from torch import SymBool, SymInt, Tensor +from torch._dispatch.python import enable_python_dispatcher +from torch._library.fake_class_registry import FakeScriptObject +from torch._logging import trace_structured +from torch._subclasses.fake_impls import fast_detach +from torch._subclasses.fake_tensor import ( + FakeTensor, + FakeTensorMode, + is_fake, + unset_fake_temporarily, +) +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx import GraphModule, Proxy, Tracer +from torch.fx.graph_module import _assign_attr +from torch.fx.node import ( + _side_effectful_need_to_be_preserved_pre_dispatch, + Argument, + Target, +) +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.nn import Module +from torch.overrides import TorchFunctionMode +from torch.utils._python_dispatch import ( + _disable_infra_mode, + _push_mode, + _unset_infra_mode, + TorchDispatchMode, +) +from torch.utils._stats import count +from torch.utils._thunk import Thunk +from torch.utils._traceback import CapturedTraceback +from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary, WeakTensorKeyDictionary + +from ._backward_state import BackwardState +from .sym_node import SymNode + + +if TYPE_CHECKING: + import types + from collections.abc import MutableMapping + + import sympy + + from torch._ops import OpOverload + from torch.fx._symbolic_trace import PHBase + from torch.types import IntLikeType + +__all__ = [ + "PythonKeyTracer", + "dispatch_trace", + "make_fx", + "DecompositionInterpreter", + "py_sym_types", + "get_innermost_proxy_mode", + "get_proxy_mode", + "handle_sym_dispatch", + "maybe_enable_thunkify", + "maybe_disable_thunkify", +] + +_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"] + +_AnyScriptObject = (torch.ScriptObject, FakeScriptObject) +_AnyScriptObjectType = Union[torch.ScriptObject, FakeScriptObject] + +aten = torch.ops.aten +prim = torch.ops.prim + +log = logging.getLogger(__name__) +not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") + +CURRENT_DECOMPOSITION_TABLE: Mapping[OpOverload, Callable] = {} + +CONSTANT_NUMEL_LIMIT = 1 + +T = TypeVar("T") +U = TypeVar("U") +_P = ParamSpec("_P") +R = TypeVar("R") +_Ts = TypeVarTuple("_Ts") + +null_ctx_type = type(nullcontext) +# We currently convert all SymInt to proxies before we use them. +# This could plausibly be handled at the Dynamo level. +pytree.register_pytree_node( + torch.Size, + lambda xs: (list(xs), None), + lambda xs, _: tuple(xs), + flatten_with_keys_fn=lambda xs: ( + [(pytree.SequenceKey(i), x) for i, x in enumerate(xs)], + None, + ), + serialized_type_name="torch.Size", +) +# Ideally unflattening should not lose info, but we unflatten +# torch.Size to tuple (see above). This is necessary because the +# torch.Size constructor only accepts ints whereas our infra often +# transforms them to non-ints, e.g. symint proxies. Anyway, losing +# such info can cause pytree mapping or spec matching to fail, so +# work around this problem using the following dict as needed. +_pytree_subclasses_that_lose_info = {torch.Size: tuple} + + +def fake_signature(fn: Callable[_P, R], nargs: int) -> Callable[_P, R]: + """FX gets confused by varargs, de-confuse it""" + argnames = ",".join(f"arg{i}" for i in range(nargs)) + return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn}) + + +@contextmanager +def decompose( + decomposition_table: Optional[Mapping[OpOverload, Callable]], +) -> Generator[Mapping[OpOverload, Callable], None, None]: + global CURRENT_DECOMPOSITION_TABLE + old_decomposition_table = CURRENT_DECOMPOSITION_TABLE + CURRENT_DECOMPOSITION_TABLE = decomposition_table or {} + try: + yield CURRENT_DECOMPOSITION_TABLE + finally: + CURRENT_DECOMPOSITION_TABLE = old_decomposition_table + + +# ensure we cannot collide with other properties +proxy_slot = object() + + +class _NoDefault: + pass + + +no_default = _NoDefault() + +from torch.types import py_sym_types, PySymType + + +class _HasMeta(Protocol): + meta: dict[str, PySymType] + + +def is_sym_node(node: _HasMeta) -> bool: + assert hasattr(node, "meta"), "All nodes traced with proxy_tensor should have meta" + return "val" in node.meta and isinstance(node.meta["val"], py_sym_types) + + +@overload +def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ... + + +@overload +def set_proxy_slot( + obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy +) -> None: ... + + +@overload +def set_proxy_slot( + obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType +) -> None: ... + + +def set_proxy_slot( + obj: Union[PySymType, _AnyScriptObjectType, Tensor], + tracer: _ProxyTracer, + proxy: object, +) -> None: + log.debug("set_proxy_slot %s (%s) %s", obj, id(obj), proxy) + if isinstance(obj, Tensor): + # We DO want to clobber proxies whenever we run an inplace operation + # on a tensor, and it affects the metadata on the proxy. + assert isinstance(proxy, _ProxyTensor) + tracer.tensor_tracker[obj] = proxy + elif isinstance(obj, (_AnyScriptObject)): + # We DO want to clobber proxies, with a similar rationale as for tensors. + assert isinstance(proxy, Proxy) + tracer.script_object_tracker[obj] = proxy + else: + # NB: Never clobber pre-existing proxy. Although the proxies + # are in principle equivalent, when we do graph partitioning + # we need there not to be spurious dependencies on tangent inputs. + # This works because primals get their SymInts set first, and + # THEN later we allocate tangent inputs. Make sure if a SymInt + # is derivable from a primal that we use that. + assert isinstance(obj, py_sym_types), type(obj) + if obj not in tracer.symnode_tracker: + tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy) + + # WAR: python test/dynamo/test_subclasses.py + # TestNestedTensor.test_basic_autograd + # + # AOTAutograd doesn't pass the "outer sizes" as an actual argument + # to make_fx, but it is made use of internally in AOTAutograd's + # call to tensor unflatten. Because the outer sizes isn't passed + # as an argument, it is therefore untracked. However, it turns + # out you luck out, because *Dynamo* will manually add the outer + # sizes as an argument so you can fix up the proxy'ness. + # + # This is probably fixed in + # https://github.com/pytorch/pytorch/pull/125941/ + import sympy + + if isinstance(obj.node.expr, sympy.Symbol): + tracer.sympy_expr_tracker[obj.node.expr] = proxy + + +def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool: + assert isinstance(obj, (Tensor, SymNode)), type(obj) + return bool(get_proxy_slot(obj, tracer, False, lambda _: True)) + + +_PySymProxyType = Thunk[Proxy] + + +@overload +def get_proxy_slot( + obj: Tensor, + tracer: _ProxyTracer, +) -> _ProxyTensor: ... + + +@overload +def get_proxy_slot( + obj: Tensor, + tracer: _ProxyTracer, + default: U, +) -> Union[_ProxyTensor, U]: ... + + +@overload +def get_proxy_slot( + obj: Tensor, + tracer: _ProxyTracer, + default: U, + transform: Callable[[_ProxyTensor], R], +) -> Union[R, U]: ... + + +@overload +def get_proxy_slot( + obj: _AnyScriptObjectType, + tracer: _ProxyTracer, +) -> Proxy: ... + + +@overload +def get_proxy_slot( + obj: _AnyScriptObjectType, + tracer: _ProxyTracer, + default: U, +) -> Union[Proxy, U]: ... + + +@overload +def get_proxy_slot( + obj: _AnyScriptObjectType, + tracer: _ProxyTracer, + default: U, + transform: Callable[[Proxy], R], +) -> Union[R, U]: ... + + +@overload +def get_proxy_slot( + obj: PySymType, + tracer: _ProxyTracer, +) -> _PySymProxyType: ... + + +@overload +def get_proxy_slot( + obj: PySymType, + tracer: _ProxyTracer, + default: T, +) -> Union[T, _PySymProxyType]: ... + + +@overload +def get_proxy_slot( + obj: PySymType, + tracer: _ProxyTracer, + default: U, + transform: Callable[[_PySymProxyType], R], +) -> Union[R, U]: ... + + +# the default argument is what to return if the slot is not set. +# the transform argument is handy if you need to extract a subfield from +# the successfully looked up result (but NOT the default.) +def get_proxy_slot( + obj: Union[Tensor, _AnyScriptObjectType, PySymType], + tracer: _ProxyTracer, + default: object = no_default, + transform: Callable = lambda x: x, +) -> object: + tracker: Any + if isinstance(obj, Tensor): + tracker = tracer.tensor_tracker + elif isinstance(obj, _AnyScriptObject): + tracker = tracer.script_object_tracker + else: + assert isinstance(obj, py_sym_types), type(obj) + tracker = tracer.symnode_tracker + + if obj not in tracker: + # Last ditch + if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker: + value = tracer.sympy_expr_tracker[obj.node.expr] + else: + if isinstance(default, _NoDefault): + raise RuntimeError( + f"{obj} ({id(obj)})is not tracked with proxy for {tracer}" + ) + return default + else: + value = tracker[obj] + res = transform(value) + return res + + +def snapshot_fake(val: Tensor, include_real: bool = False) -> Optional[Tensor]: + # val.detach() will also eventually call fast_detach(), + # but this saves us a full trip into __torch_dispatch__ + # (snapshot_fake is called a lot) + if isinstance(val, FakeTensor): + return fast_detach(val.fake_mode, val, include_real) + else: + return val.detach() + + +_ExtractValType = Optional[ + Union[ + PySymType, + _AnyScriptObjectType, + BackwardState, + list["_ExtractValType"], + tuple["_ExtractValType", ...], + dict[str, "_ExtractValType"], + Tensor, + int, + float, + bool, + ] +] + + +def extract_val(val: _ExtractValType, include_real: bool = False) -> _ExtractValType: + if is_fake(val): + return snapshot_fake(val, include_real=include_real) + elif isinstance(val, py_sym_types): + return val + elif isinstance(val, _AnyScriptObject): + return val + elif isinstance(val, BackwardState): + return val + elif isinstance(val, (list, tuple)): + return val.__class__([extract_val(x) for x in val]) + elif isinstance(val, dict): + return {k: extract_val(v) for k, v in val.items()} + elif isinstance(val, Tensor): + if not val.is_sparse: + # NB: Kinda hacky, but we should try to get val as the metadata + # everywhere + # TODO: This doesn't properly track storages. A more robust + # approach would be to maintain a per-trace FakeTensorMode and + # from_real_tensor to create fake values (don't forget to + # snapshot_fake) + from torch._guards import detect_fake_mode + + fake_tensor_mode = detect_fake_mode(val) + if not fake_tensor_mode: + fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) + with fake_tensor_mode: + return torch.empty_strided( + val.shape, val.stride(), device=val.device, dtype=val.dtype + ) + else: + return None + elif isinstance(val, (int, float, bool)): + return val + elif val is None: + return None + + typing_extensions.assert_never(val) + + +@contextmanager +def _enable_thunkify( + tracer: _ProxyTracer, *, enable: bool = True +) -> Generator[None, None, None]: + """ + Enable thunkification inside the context manager. Thunkification prevents + SymNode computation from directly being traced into an FX graph; instead, + the compute is only added to the graph if it is actually used. This helps + us track SymNode compute when it is computed (since we need /something/ + to put in the tracker) even if it is unlikely to be used. + """ + old = tracer.enable_thunkify + tracer.enable_thunkify = enable + try: + yield + finally: + tracer.enable_thunkify = old + + +@contextmanager +def maybe_disable_thunkify() -> Generator[None, None, None]: + """Within a context, disable thunkification. See :func:`maybe_enable_thunkify` + for more details. This is helpful if you have a wrapper function which + you want to enable thunkification on, but in some segment on the inside (say, + the original user function), you want to disable thunkification as you know + it is not needed there. + """ + proxy_mode = get_proxy_mode() + if proxy_mode is not None: + with _enable_thunkify(proxy_mode.tracer, enable=False): + yield + else: + yield + + +@contextmanager +def maybe_enable_thunkify() -> Generator[None, None, None]: + """Within this context manager, if you are doing make_fx tracing, we will thunkify + all SymNode compute and avoid tracing it into the graph unless it is actually needed. + You should prefer to avoid using this as much as possible, as lazy evaluation of + SymNode tracing can lead to long chains of thunks which will stack overflow + if you evaluate them. However, this is currently sometimes necessary as there + are buggy parts of PT2 which will fail with "s0 is not tracked with proxy" error + due to insufficient tracing of SymNode computation. + """ + proxy_mode = get_proxy_mode() + if proxy_mode is not None: + with _enable_thunkify(proxy_mode.tracer): + yield + else: + yield + + +# Note [invariants for node meta 'val'] +# What invariants do we have for the 'val' set on the FX node? It has accurate +# metadata... but only for metadata that exists "below" all other subsystems +# (most notably autograd, but also vmap, functorch transforms, etc). This means +# you can get the dtype, shape, stride, storage, but you CANNOT get requires_grad, +# grad_fn, _base (_base actually may be set due to recursive call to +# ADInplaceOrView, but you shouldn't rely on it.) +def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy: + proxy.node.meta["val"] = extract_val( + val, include_real=(proxy.node.op == "placeholder") + ) + + with _enable_thunkify(proxy.tracer): # type: ignore[arg-type] + # Best effort tensor_meta setting; prefer using val! + if is_fake(val): + proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val) + elif isinstance(val, Tensor) and not val.is_sparse: + proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val) + return proxy + + +def thunkify( + tracer: _ProxyTracer, f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs +) -> Thunk[R]: + """ + Delays computation of f until it's called again + Also caches the result + """ + if tracer.enable_thunkify: + return Thunk(functools.partial(f, *args, **kwargs)) + else: + r = f(*args, **kwargs) + return Thunk(lambda: r) + + +def track_tensor( + tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer +) -> None: + def try_set_proxy_slot( + outer_s: IntLikeType, + proxy_callable: Callable[Concatenate[PySymType, _P], Proxy], + *args: _P.args, + **kwargs: _P.kwargs, + ) -> None: + assert callable(proxy_callable) + if isinstance(outer_s, SymInt): + with _enable_thunkify(tracer): + set_proxy_slot( + outer_s, + tracer, + thunkify(tracer, proxy_callable, outer_s, *args, **kwargs), + ) + + # The basic idea is that we need to associate each tensor/SymInt + # with a Proxy. How do we setup this association? We just store + # the proxy on the proxy slot of the object, keyed on the tracer + # (so that if we have multiple tracers at the same time, they + # don't clobber each other.) + for i, s in enumerate(tensor.shape): + try_set_proxy_slot( + s, + lambda x, i: set_meta( + tracer.create_proxy( + "call_function", torch.ops.aten.sym_size.int, (proxy, i), {} + ), + x, + ), + i, + ) + + if not is_sparse_any(tensor): + for i, s in enumerate(tensor.stride()): + try_set_proxy_slot( + s, + lambda x, i: set_meta( + tracer.create_proxy( + "call_function", torch.ops.aten.sym_stride.int, (proxy, i), {} + ), + x, + ), + i, + ) + + try_set_proxy_slot( + tensor.numel(), + lambda x: set_meta( + tracer.create_proxy( + "call_function", torch.ops.aten.sym_numel.default, (proxy,), {} + ), + x, + ), + ) + if not is_sparse_any(tensor): + try_set_proxy_slot( + tensor.storage_offset(), + lambda x: set_meta( + tracer.create_proxy( + "call_function", + torch.ops.aten.sym_storage_offset.default, + (proxy,), + {}, + ), + x, + ), + ) + set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant)) + + +_NestedProxys = Union[ + Proxy, Sequence["_NestedProxys"], Mapping[object, "_NestedProxys"] +] +_NestedTensors = Union[ + Tensor, Sequence["_NestedTensors"], Mapping[object, "_NestedTensors"] +] + + +def track_tensor_tree( + inner_res: T, + proxy_res: _NestedProxys, + *, + constant: Optional[_NestedTensors], + tracer: _ProxyTracer, +) -> T: + # NB: We call set_unbacked_bindings only on the *topmost* call to + # track_tensor_tree, not recursive calls. This is because there must + # be only ONE unbacked_binding proxy call, and it should be the one + # where all of the unbacked SymInts actually first come into existence. + # If you call this again on the inner proxies for the tuple projections, + # you will have multiple unbacked_bindings for the same symbol, but + # they're not going to show up anywhere. + # + # I was briefly deceived into setting unbacked bindings recursively when + # working on https://github.com/pytorch/pytorch/pull/133585 because I + # observed that some extra unbacked bindings were needed to handle some + # higher order operator code. But actually it looks like this was + # just an unrelated bug that needed to be fixed separately. + _set_unbacked_bindings(inner_res, proxy_res) + + def wrap_with_proxy( + e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors] + ) -> None: + if isinstance(e, Tensor): + assert isinstance(proxy, Proxy) + assert constant is None or isinstance(constant, Tensor) + track_tensor(e, proxy, tracer=tracer, constant=constant) + set_meta(proxy, e) + elif isinstance(e, py_sym_types): + assert isinstance(proxy, Proxy) + # NB: eagerly set meta here, so that the numbering is in order + set_meta(proxy, e) + set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy)) + elif isinstance(e, _AnyScriptObject): + assert isinstance(proxy, Proxy) + set_proxy_slot(e, tracer, proxy) + set_meta(proxy, e) + elif isinstance(e, (tuple, list)): + # example use case: allreduce_ returns ([tensor], work) + if isinstance(proxy, fx.Proxy): + set_meta(proxy, e) + + def get_constant( + c: Optional[_NestedTensors], idx: int + ) -> Optional[_NestedTensors]: + if c is None: + return None + else: + assert isinstance(c, (list, tuple)) + return c[idx] + + for idx, ee in enumerate(e): + # Use an indexer here - if proxy is a List then it will unwrap + # it. If it's a Proxy then it will proxy the getelem. + wrap_with_proxy(ee, proxy[idx], get_constant(constant, idx)) # type: ignore[index] + + elif isinstance(e, dict): + # example use case: triton_kernel_wrapper takes arguments as kwargs + + # In theory we could support const-prop when proxy-tensor-tracing + # operators that returns dicts of tensors, but we have no use case + # for it today (since the only op we currently trace that can + # return a dict is triton_kernel_wrapper_functional/mutation, + # which does not participate in const-prop) + assert constant is None + + if isinstance(proxy, fx.Proxy): + set_meta(proxy, e) + + for key, val in e.items(): + wrap_with_proxy(val, proxy[key], None) # type: ignore[index] + + elif isinstance(e, BackwardState): + assert isinstance(proxy, Proxy) + set_meta(proxy, e) + e.proxy = proxy + else: + # intentionally pass on primitives + pass + + wrap_with_proxy(inner_res, proxy_res, constant) + + return inner_res + + +@dataclass +class _ProxyTensor: + proxy: Proxy + constant: Optional[Tensor] + + +def fetch_sym_proxy( + tracer: _ProxyTracer, +) -> Callable[[PySymType], Union[bool, int, float, Proxy]]: + def inner(e: PySymType) -> Union[int, bool, float, Proxy]: + n = e.node + if n.constant is not None: + return n.constant + if e.node.expr.is_number: + if isinstance(e, SymBool): + return bool(e.node.expr) + elif isinstance(e, SymInt): + return int(e.node.expr) + return float(e.node.expr) + else: + assert isinstance(e, py_sym_types) + # NB: we REQUIRE all symints to be tracked + return get_proxy_slot(e, tracer).force() + + return inner + + +@overload +def fetch_object_proxy( + tracer: _ProxyTracer, t: Tensor +) -> Union[_ProxyTensor, Tensor]: ... + + +@overload +def fetch_object_proxy( + tracer: _ProxyTracer, t: _AnyScriptObjectType +) -> Union[Proxy, _AnyScriptObjectType]: ... + + +@overload +def fetch_object_proxy( + tracer: _ProxyTracer, t: PySymType +) -> Union[_PySymProxyType, PySymType]: ... + + +def fetch_object_proxy( + tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType] +) -> object: + return get_proxy_slot(t, tracer, t) + + +HANDLED_TYPES = (Tensor, torch.nn.Parameter, FakeTensor) + + +def _maybe_record_pointwise_barrier( + func: object, proxy_mode: ProxyTorchDispatchMode +) -> None: + """ + Records pointwise operators in user program (non decomposed) that were output in fp16/bf16 + """ + if proxy_mode.decomp_layers or not proxy_mode.emulate_precision_casts: + return + + if ( + not isinstance(func, torch._ops.OpOverload) + or torch.Tag.pointwise not in func.tags + ): + return + + last_node = next(iter(reversed(proxy_mode.tracer.graph.nodes))) + t = last_node.meta.get("val") + if not isinstance(t, torch.Tensor) or t.dtype not in ( + torch.bfloat16, + torch.float16, + ): + return + + last_node.meta["low_precision_pointwise_barrier"] = True + + +def proxy_call( + proxy_mode: ProxyTorchDispatchMode, + func: OpOverload, + pre_dispatch: bool, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + unrecognized_types: list[type] = [] + flat_args_kwargs, spec = pytree.tree_flatten((args, kwargs)) + + def can_handle_tensor(x: Tensor) -> bool: + r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer) + if proxy_mode._allow_fake_constant: + r = r or type(x) in (torch._subclasses.FakeTensor,) + if not r: + unrecognized_types.append(type(x)) + return r + + # If there are any tensor subclasses, we need to handle those tensor subclasses first + # TODO: we could use types to test this + if not all(can_handle_tensor(x) for x in flat_args_kwargs if isinstance(x, Tensor)): + not_implemented_log.debug( + "ProxyTensorMode tensors without proxy had unrecognized subclasses: %s", + unrecognized_types, + ) + return NotImplemented + + r = maybe_handle_decomp(proxy_mode, func, args, kwargs) + if r is not NotImplemented: + _maybe_record_pointwise_barrier(func, proxy_mode) + return r + + # For pre-autograd tracing, we do not want to run CompositeImplicit decomps. + if not pre_dispatch and func not in [ + torch.ops.aten.size.default, + torch.ops.aten.stride.default, + torch.ops.aten.storage_offset.default, + ]: + with proxy_mode: + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r + + if func is torch.ops.aten.is_nonzero.default: + with proxy_mode: + torch._check( + args[0].numel() == 1, # type: ignore[attr-defined] + lambda: "Boolean value of Tensor with more than one value is ambiguous", + ) + return (args[0] != 0).item() # type: ignore[attr-defined] + + tracer = proxy_mode.tracer + f_flat_args_kwargs = [ + ( + fetch_object_proxy(tracer, x) + if isinstance(x, (Tensor, _AnyScriptObject)) + else x + ) + for x in flat_args_kwargs + ] + + # If there are SymInts, we also should not consider this constant. + # However, fake tensor handling of SymInts is sufficiently broken that + # I couldn't write a test for this case + all_constant = ( + not any( + t.constant is None + for t in f_flat_args_kwargs + if isinstance(t, _ProxyTensor) + ) + # TODO: maybe constant SymInts should also be allowed? Not sure if + # this can happen + and not any(isinstance(x, py_sym_types) for x in flat_args_kwargs) + ) + + if torch.Tag.data_dependent_output in func.tags: + # Check if all of the Tensor inputs are constants + if all_constant: + const_flat_args_kwargs = [ + t.constant if isinstance(t, _ProxyTensor) else t + for t in f_flat_args_kwargs + ] + const_args, const_kwargs = pytree.tree_unflatten( + const_flat_args_kwargs, spec + ) + with unset_fake_temporarily(): + return func(*const_args, **const_kwargs) + # If any of the Tensor inputs are "real" (not FakeTensor), we may + # incorrectly burn in constants by allowing this access. Raise + # an error in this case + if proxy_mode._error_on_data_dependent_ops and pytree.tree_all_only( + Tensor, lambda t: not is_fake(t), (args, kwargs) + ): + raise RuntimeError( + f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! " + "It's likely that this is caused by data-dependent control flow or similar. " + "It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' " + "in your make_fx call." + ) + + proxy_flat_args_kwargs = [ + e.proxy if isinstance(e, _ProxyTensor) else e for e in f_flat_args_kwargs + ] + proxy_flat_args_kwargs = [ + (fetch_sym_proxy(proxy_mode.tracer)(e) if isinstance(e, py_sym_types) else e) + for e in proxy_flat_args_kwargs + ] + proxy_args, proxy_kwargs = pytree.tree_unflatten(proxy_flat_args_kwargs, spec) + + # When we trace through a torch.tensor invocation, you never actually + # see a torch.ops.aten.tensor call. Instead, the way this function is + # implemented internally is that we allocate a plain tensor (this is + # *guaranteed* to be a plain tensor, we disable all modes when doing + # so), and then call at::lift_fresh on it (to give modes a chance to do + # their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed + # to be freshly allocated, so we want lift_fresh to be a no-op (directly + # returning the input argument). + # + # Here is the basic problem: when we trace this sequence of executions + # into an FX graph, what happens to this call sequence? Traditionally, + # tensor constants get interned as buffers on the FX GraphModule. But + # this is dangerous. Consider: + # + # x = torch.tensor(1) + # x.add_(2) + # + # Naively, this traces into: + # + # t = self._tensor_constant0 # initialized to torch.tensor(1) + # x = torch.ops.aten.lift_fresh(t) + # x.add_(2) + # + # If lift_fresh returns t directly, the subsequent add_ call will + # modify the tensor constant. Really, the problem is we've violated + # the invariant the argument to lift is fresh. So what we should + # preserve the invariant by replacing lift_fresh with lift_fresh_copy: + # + # t = self._tensor_constant0 # initialized to torch.tensor(1) + # x = torch.ops.aten.lift_fresh_copy(t) + # x.add_(2) + # + # This is what the overload modification does. + if func is torch.ops.aten.lift_fresh.default: + func = torch.ops.aten.lift_fresh_copy.default + + proxy_out = proxy_mode.tracer.create_proxy( + "call_function", + func, + proxy_args, + proxy_kwargs, + name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__), + ) + + with _enable_thunkify(proxy_mode.tracer): + out = func(*args, **kwargs) + + # In some circumstances, we will be tracing in a situation where a tensor + # is *statically* known to be a constant (currently, this only happens if + # you run torch.tensor; deterministic factory functions like torch.arange + # don't get this treatment). When the tensor in question is small, it's + # helpful to due constant propagation in case we call item() (in which + # case we can return the constant value that is known, rather than give + # an error.) The logic here tests if constant propagation is possible + # (because all of the inputs are constant). If so, we disable fake tensor + # mode (if it is on) and do true compute on the constant. + # + # It's worth highlighting that we're making a policy decision here. + # There is a potential that the tensor is actually quite large, and we + # don't actually want to run the compute. The tensor being quite large + # is one of the reasons why factory functions don't get this treatment + # (since they can be quite large; if a parameter is initialized to a + # constant value it will be!) Similarly, there is also a potential + # to run an operator that blows up the size of a small tensor; we don't + # protect against this case, but we could force, e.g., only single + # element constant computation by testing the numel of the result before + # propagating const-ness. Similarly, we don't require the constant to + # live on CPU, but we could. + any_constant = any( + t.constant is not None + for t in f_flat_args_kwargs + if isinstance(t, _ProxyTensor) + ) + + constant = None + + def tensor_numel_in_limit(t: Tensor) -> bool: + return t.numel() <= CONSTANT_NUMEL_LIMIT + + # If this is a lift, the input tensor is guaranteed to be a + # constant, so we keep a copy of the original argument along so + # we can query it if we're asked to item() it at some later point + if ( + func is torch.ops.aten.lift_fresh_copy.default + and out.numel() <= CONSTANT_NUMEL_LIMIT + ): + with unset_fake_temporarily(): + assert isinstance(args[0], (Proxy, Tensor)), type(args[0]) + constant = args[0].clone() + elif ( + torch.Tag.nondeterministic_seeded not in func.tags + and all_constant + and any_constant + and pytree.tree_all_only(Tensor, tensor_numel_in_limit, out) + ): + # NB: do NOT include factories as constants + with unset_fake_temporarily(): + const_flat_args_kwargs = [ + t.constant if isinstance(t, _ProxyTensor) else t + for t in f_flat_args_kwargs + ] + const_args, const_kwargs = pytree.tree_unflatten( + const_flat_args_kwargs, spec + ) + constant = func(*const_args, **const_kwargs) + else: + constant = None + + track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer) + _maybe_record_pointwise_barrier(func, proxy_mode) + return out + + +class _SymNodeDict: + """ + Wrapper around a dictionary that will hash SymInts with their nodes + """ + + def __init__(self) -> None: + self.sym_node_dict: dict[PySymType, _PySymProxyType] = {} + + def __setitem__(self, key: PySymType, value: _PySymProxyType) -> None: + self.sym_node_dict[key.node] = value + + def __getitem__(self, key: PySymType) -> _PySymProxyType: + return self.sym_node_dict[key.node] + + def __contains__(self, key: PySymType) -> bool: + return key.node in self.sym_node_dict + + def get( + self, key: PySymType, default: Optional[_PySymProxyType] = None + ) -> _PySymProxyType: + # dict.get()'s annotation doesn't accept `None` when the value type + # isn't Optional. + return self.sym_node_dict.get(key.node, default) # type: ignore[arg-type, return-value] + + def __iter__(self) -> Any: + raise NotImplementedError + + def __len__(self) -> int: + return len(self.sym_node_dict) + + +class PythonKeyTracer(Tracer): + script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] + symnode_tracker: _SymNodeDict + sympy_expr_tracker: dict[sympy.Symbol, object] + tensor_tracker: MutableMapping[Tensor, _ProxyTensor] + torch_fn_counts: dict[OpOverload, int] + enable_thunkify: bool = False + stack_trace: bool = False + + def __init__(self) -> None: + super().__init__(autowrap_modules=()) # type: ignore[arg-type] + self.tensor_tracker = WeakTensorKeyDictionary() + self.symnode_tracker = _SymNodeDict() + self.script_object_tracker = WeakIdKeyDictionary( + dict=None, ref_type=_WeakHashRef + ) + self.sympy_expr_tracker = dict() + + # Stores the torch function that was called during tracing + self.torch_fn_metadata = None + # Stores the counts for every torch function called. This is to help + # distinguish between different calls to the same torch function. + self.torch_fn_counts = {} + self.enable_thunkify = False + + # In general, we don't want to make modules leaves. In principle, users of + # this tracer might want to override this in order to turn a couple specific + # modules into leaves in the traced graph. + def call_module( + self, + m: Module, + forward: Callable[..., Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: + return forward(*args, **kwargs) + + # We don't want to turn getattr calls into proxies. So we just return the actual value. + def getattr( + self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy] + ) -> object: + return attr_val + + def create_arg(self, a: object) -> fx.node.Node: + if isinstance(a, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node("get_attr", n, (), {}) + + qualname = self.get_fresh_qualname("_param_constant") + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + elif isinstance(a, py_sym_types): + assert a.node.constant is not None + return a.node.constant + return super().create_arg(a) # type: ignore[return-value] + + @overload + def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: ... + + @overload + def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: ... + + @overload + def unwrap_proxy( + self, e: _AnyScriptObjectType + ) -> Union[Proxy, _AnyScriptObjectType]: ... + + def unwrap_proxy(self, e: T) -> object: + if isinstance(e, Tensor): + return get_proxy_slot(e, self, e, lambda x: x.proxy) + elif isinstance(e, py_sym_types): + return get_proxy_slot(e, self, e, lambda e: e.force()) + elif isinstance(e, _AnyScriptObject): + return get_proxy_slot(e, self, e) + else: + return e + + def create_node( + self, + kind: str, + target: Target, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> torch.fx.Node: + node = super().create_node(kind, target, args, kwargs, name, type_expr) # type: ignore[arg-type] + + # stack_trace + if ( + self.stack_trace + and "stack_trace" not in node.meta + and node.op not in ["placeholder", "output"] + ): + user_frame_summary = CapturedTraceback.extract().summary() + if user_frame_summary: + # we retain frames from forward() calls, or ops + # located in torch/__init__.py (e.g. sym_int, sym_constrain_range, vmap) + stack_trace = [ + frame + for frame in user_frame_summary + if ( + frame.name == "forward" + or frame.filename.endswith("torch/__init__.py") + ) + ] + # filter out forward() frames from fx/_symbolic_trace.py, export/_trace.py + # this is hardcoded, but leads to a much cleaner stack trace + stack_trace = [ + frame + for frame in stack_trace + if not frame.filename.endswith( + ("fx/_symbolic_trace.py", "export/_trace.py") + ) + ] + if ( + stack_trace + ): # empty list for strict mode, dynamo should handle stack_trace + stack_trace = traceback.StackSummary.from_list(stack_trace) + node.meta["stack_trace"] = "".join(stack_trace.format()).strip() + + if kind == "get_attr": + assert isinstance(target, str) + attr = getattr(self.root, target) + if isinstance(attr, torch.Tensor): + with disable_proxy_modes_tracing(): + node.meta["val"] = extract_val(attr) + + def map_fn(v: Any) -> Optional[_ExtractValType]: + if not isinstance(v, torch.fx.Node) or "val" not in v.meta: + return None + val = v.meta["val"] + # other subclasses like FunctionalTensor error on `extract_val` + # "Attempting to use FunctionalTensor on its own." just store FakeTensors for now + if isinstance(val, torch.Tensor) and not isinstance(val, FakeTensor): + return None + return extract_val(v.meta["val"]) + + if _should_save_eager_input_vals(target, (args, kwargs)): + # NOTE "eager_input_vals" + # We save the original (args, kwargs) FakeTensor values for nodes + # that have exact stride requirements. This is useful downstream. + # We use this information inside Inductor to ensure that inputs to + # stride-sensitive operators have the correct strides. + arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn) # type: ignore[misc, arg-type] + node.meta["eager_input_vals"] = (arg_inp, kwarg_inp) + + return node + + +def _should_save_eager_input_vals( + target: Any, + args_kwargs: Optional[tuple[tuple[Argument, ...], dict[str, Argument]]] = None, +) -> bool: + from torch._higher_order_ops.invoke_subgraph import InvokeSubgraphHOP + + if not callable(target): + return False + if isinstance( + target, + ( + torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional, + torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation, + InvokeSubgraphHOP, + ), + ): + return True + if args_kwargs is not None and ( + target is torch.ops.higher_order.auto_functionalized + or target is torch.ops.higher_order.auto_functionalized_v2 + ): + args = args_kwargs[0] + assert isinstance( + args[0], (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ) + return _should_save_eager_input_vals(args[0], None) + if target is torch.ops.higher_order.with_effects: + # TODO: inductor lowering for with_effects needs to be updated to propagate + # the arg_kwarg_vals + return False + if isinstance(target, torch._ops.HigherOrderOperator): + if pytree.tree_any(_should_save_eager_input_vals, args_kwargs): + raise RuntimeError( + f"NYI: The HOP {target} has an input that is an OpOverload that " + f"needs exact strides. We probably need special logic to " + f"propagate the FakeTensor vals. Please file an issue." + ) + if isinstance(target, torch._ops.OpOverload): + from torch._library.utils import get_layout_constraint_tag + + return get_layout_constraint_tag(target) == torch._C.Tag.needs_exact_strides + return False + + +def _make_temp_remove_mode_context_manager( + mode_ty: type[TorchFunctionMode], +) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: + @contextmanager + def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: + from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode + + temp_elements = [] + removed_mode = None + + while _len_torch_function_stack() > 0: + mode = _pop_mode() + if isinstance(mode, mode_ty): + removed_mode = mode + break + else: + temp_elements.append(mode) + + for mode in reversed(temp_elements): + _push_mode(mode) + + try: + yield removed_mode + + finally: + if removed_mode is not None: + count = len(temp_elements) + while count > 0: + mode = _pop_mode() + count -= 1 + + temp_elements.append(removed_mode) + + for mode in reversed(temp_elements): + _push_mode(mode) + + return context_manager_fn + + +@torch._disable_dynamo +def dispatch_trace( + root: Union[Module, Callable], + tracer: Tracer, + concrete_args: Optional[tuple[Any, ...]] = None, +) -> GraphModule: + graph = tracer.trace(root, concrete_args) # type: ignore[arg-type] + + # NB: be careful not to DCE .item() calls + def impure_pred(n: fx.Node) -> bool: + from .symbolic_shapes import is_accessor_node + + # Always defer to the built-in notion of impure + if n.is_impure(): + return True + + # Accessors always OK to DCE + if is_accessor_node(n): + return False + + # If the operator in question takes SymInt args to SymInt output, + # we assume it's pure and OK to DCE + if ( + isinstance(n.meta.get("val"), py_sym_types) + and + # NB: constant args ok + all( + isinstance(a.meta.get("val"), py_sym_types) + for a in n.args + if isinstance(a, fx.Node) + ) + ): + return False + + # No idea, just assume it's not OK + return True + + graph.eliminate_dead_code(impure_pred) + from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints + + dedupe_symints(graph) + name = root.__class__.__name__ if isinstance(root, Module) else root.__name__ + return fx._lazy_graph_module._make_graph_module(tracer.root, graph, name) + + +def wrap_key( + f: Callable[[Unpack[_Ts]], R], + tensors: tuple[Unpack[_Ts]], + tracer: _ProxyTracer, + pre_dispatch: bool, +) -> Callable[_P, R]: + flat_tensors, _tensors_spec = pytree.tree_flatten(tensors) + + @functools.wraps(f) + def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R: + flat_proxies, _proxies_spec = pytree.tree_flatten(proxies) + assert len(flat_proxies) == len(flat_tensors) + with disable_proxy_modes_tracing() as m: + assert isinstance(m, ProxyTorchDispatchMode) + track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer) + + def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]: + return get_proxy_slot(t, tracer, t, lambda x: x.proxy) # type: ignore[attr-defined] + + out = f(*tensors) # type:ignore[call-arg] + out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out) + out = pytree.tree_map_only( + _AnyScriptObject, lambda t: get_proxy_slot(t, tracer, t, lambda x: x), out + ) + + def get_sym_proxy_slot(t: PySymType) -> Proxy: + return get_proxy_slot(t, tracer).force() + + out = pytree.tree_map_only(py_sym_types, get_sym_proxy_slot, out) + return out + + return wrapped + + +# TODO: Make downstream users of this work with OperatorBase +ORIGINAL_ATEN: Optional[object] = None + + +@contextmanager +def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]: + global ORIGINAL_ATEN + if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta(): + ORIGINAL_ATEN = func + fx_traceback.current_meta["original_aten"] = func + try: + yield + finally: + ORIGINAL_ATEN = None + fx_traceback.current_meta["original_aten"] = None + else: + yield + + +class TorchFunctionMetadataMode(TorchFunctionMode): + def __init__(self, tracer: _ProxyTracer) -> None: + self.tracer = tracer + + def __torch_function__( + self, + func: OpOverload, + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...] = (), + kwargs: Optional[dict[str, object]] = None, + ) -> object: + kwargs = kwargs or {} + self.tracer.torch_fn_metadata = func + self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1 + return func(*args, **kwargs) + + +_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager( + TorchFunctionMetadataMode +) + + +# This mode is **only** used for pre_dispatch tracing. +# In particular, we need to make sure that autograd/autocast API's +# that do not desugar into dispatcher operators stay in the graph. +class PreDispatchTorchFunctionMode(TorchFunctionMode): + def __init__(self, tracer: _ProxyTracer) -> None: + self.tracer = tracer + # The input to torch.amp.autocast_mode._exit_autocast graph node should be the + # enter_autocast node. So we have to save the enter autocast node here, and assign it + # to the exit_autocast call_function node. + self.enter_autocast_nodes: list[torch.fx.Node] = [] + + def __torch_function__( + self, + func: Union[OpOverload, Callable], + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...] = (), + kwargs: Optional[dict[str, object]] = None, + ) -> object: + kwargs = kwargs or {} + if func in _side_effectful_need_to_be_preserved_pre_dispatch: + # It's for passing the export verifier which needs to verify the meta['val'] + # TODO(tmanlaibaatar): we should systematically couple it with expoert verifier, + # instead of hardcoding it here. + # T203648563 + if func == torch.amp.autocast_mode._exit_autocast: + enter_node = self.enter_autocast_nodes.pop() + args = (enter_node,) + node = self.tracer.create_node("call_function", func, args, {}) # type: ignore[arg-type] + if func == torch.amp.autocast_mode._enter_autocast: + self.enter_autocast_nodes.append(node) + if func in [ + torch._C._set_grad_enabled, + torch.amp.autocast_mode._enter_autocast, + torch.amp.autocast_mode._exit_autocast, + ]: + node.meta["val"] = None + return node + # Don't actually run the function! We just want to trace the calls + # into a graph. We don't actualy want to change global autograd state. + return func(*args, **kwargs) + + +_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager( + PreDispatchTorchFunctionMode +) + + +class ProxyTorchDispatchMode(TorchDispatchMode): + # Ensure this is read-only; this exists only for legacy reasons + @property + def enable_tracing(self) -> bool: + return True + + def __init__( + self, + tracer: _ProxyTracer, + tracing_mode: str, + pre_dispatch: bool = False, + _allow_fake_constant: bool = False, + _error_on_data_dependent_ops: bool = True, + ) -> None: + dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None + super().__init__(dk) + self.tracer = tracer + self.tracing_mode = tracing_mode + self.pre_dispatch = pre_dispatch + self._allow_fake_constant = _allow_fake_constant + self._error_on_data_dependent_ops = _error_on_data_dependent_ops + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + self._mode_key = torch._C._TorchDispatchModeKey.PROXY + # Every time we enter a mode, we maintain a stack telling us what the previous + # ProxyTorchDispatchMode state was (if there was any). + # This lets us properly reset the state on exit. + self.enter_stack: list[Optional[ProxyTorchDispatchMode]] = [] + self.decomp_layers: int = 0 + from torch._inductor import config + + self.emulate_precision_casts: bool = config.emulate_precision_casts + + @count + def __torch_dispatch__( + self, + func: OpOverload, + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...] = (), + kwargs: Optional[dict[str, object]] = None, + ) -> object: + with set_original_aten_op(func): + kwargs = kwargs or {} + + if func in (prim.device.default,): + return func(*args, **kwargs) + + return proxy_call(self, func, self.pre_dispatch, args, kwargs) + + def __enter__(self) -> Self: + # Stash and store the previous proxy mode (there may or may not be one) + maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY) + self.enter_stack.append(maybe_prev_proxy_mode) + return super().__enter__() + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> Optional[bool]: + b = super().__exit__(exc_type, exc_value, traceback) + + # Re-enable the previous proxy mode, if there was one. + mb_previous_proxy_mode = self.enter_stack.pop() + if mb_previous_proxy_mode is not None: + _push_mode(mb_previous_proxy_mode) + + return b + + @classmethod + def is_infra_mode(cls) -> bool: + return True + + def _compute_proxy( + self, func: OpOverload, args: tuple[object, ...], out: PySymType + ) -> Proxy: + # Handle torch.sym_sum + n_args: tuple[object, ...] + if len(args) == 1 and isinstance(args[0], (list, tuple)): + n_args = ( + tuple( + ( + get_proxy_slot(a, self.tracer).force().node + if isinstance(a, py_sym_types) + else a + ) + for a in args[0] + ), + ) + else: + n_args = tuple( + ( + get_proxy_slot(a, self.tracer).force().node + if isinstance(a, py_sym_types) + else a + ) + for a in args + ) + + # func doesn't have a __torch_function__ that Proxy can interpose, so + # we gotta do it manually + n_out = self.tracer.create_node("call_function", func, n_args, {}) # type: ignore[arg-type] + p_out = fx.Proxy(n_out, self.tracer) + set_meta(p_out, out) + return p_out + + def __sym_dispatch__( + self, + func: OpOverload, + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + # Peephole optimize multiply by one + # NB: be careful not to trigger guards here! + if func == operator.mul: + if isinstance(args[1], int) and args[1] == 1: + return args[0] + elif isinstance(args[0], int) and args[0] == 1: + return args[1] + + # For speed, we assume there are no nested data structures + # (otherwise we could use tree_map) + # We also assume there are no keyword arguments. + assert not kwargs + out = func(*args, **kwargs) + + # If func returned a constant, we don't need to trace; we have + # determined that the result is constant (no matter if the inputs + # were symbolic) and it is no longer necessary to trace the + # computation. This could occur if func triggered some guards. + if isinstance(out, py_sym_types): + p_out_thunk = thunkify( + self.tracer, self._compute_proxy, func=func, args=args, out=out + ) + set_proxy_slot(out, self.tracer, p_out_thunk) + + return out + + +class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer): + script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] + symnode_tracker: MutableMapping[PySymType, _PySymProxyType] + tensor_tracker: MutableMapping[Tensor, _ProxyTensor] + sympy_expr_tracker: dict[sympy.Symbol, object] + torch_fn_metadata: Optional[OpOverload] + torch_fn_counts: dict[OpOverload, int] + enable_thunkify: bool = False + + def __init__(self, graph: fx.graph.Graph) -> None: + super().__init__(graph) + self.symnode_tracker = weakref.WeakKeyDictionary() + self.tensor_tracker = WeakTensorKeyDictionary() + self.sympy_expr_tracker = {} + self.script_object_tracker = WeakIdKeyDictionary( + dict=None, ref_type=_WeakHashRef + ) + # Stores the torch function that was called during tracing + self.torch_fn_metadata = None + # Stores the counts for every torch function called. This is to help + # distinguish between different calls to the same torch function. + self.torch_fn_counts = {} + + +# TODO: I'm not sure what the point of this class is; you can just +# make_fx through a regular Interpreter +class DecompositionInterpreter(fx.Interpreter): + def __init__( + self, + module: fx.GraphModule, + new_graph: fx.Graph, + decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, + **kwargs: object, + ) -> None: + super().__init__(module, **kwargs) # type: ignore[arg-type] + self.new_graph = new_graph + self.tracer = _GraphAppendingTracerEx(self.new_graph) + # Blegh + self.decomposition_table = decomposition_table or {} + self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") + + def placeholder( + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + out = super().placeholder(target, args, kwargs) # type: ignore[arg-type] + proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer) + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + # TODO handle case where the first character of target is '*' + return out + + def get_attr( + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + out = super().get_attr(target, args, kwargs) # type: ignore[arg-type] + proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer) + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + return out + + # call_function, call_method, call_module get traced automatically by the outer mode. + + def output( + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + out = super().output(target, args, kwargs) # type: ignore[arg-type] + + def get_proxy_node(x: _ProxyTensor) -> fx.node.Node: + return x.proxy.node + + def unwrap(e: Tensor) -> Union[Tensor, fx.Node]: + return get_proxy_slot(e, self.tracer, e, get_proxy_node) + + self.new_graph.output(pytree.tree_map(unwrap, out)) + return out + + def run(self, *args: object, **kwargs: object) -> object: + # Should enter the mode at least once for being able to restore it later + # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025 + with decompose(self.decomposition_table), self.mode: + return super().run(*args, **kwargs) # type: ignore[arg-type] + + +def wrapper_and_args_for_make_fx( + func: Callable[..., R], args: tuple[object, ...], kwargs: dict[str, object] +) -> tuple[Callable[[list[object]], R], list[object]]: + # make_fx doesn't support kwargs, so we need to do this flattening + # and then unflatten the args before calling func + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + def wrapped(flat_args: list[object]) -> R: + fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec) + return func(*fn_args, **fn_kwargs) + + return wrapped, flat_args + + +@contextmanager +def disable_autocast_cache() -> Generator[None, None, None]: + old_value = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(False) + try: + yield + finally: + torch.set_autocast_cache_enabled(old_value) + + +class _ModuleNotInstalledAsSubmoduleError(NameError): + pass + + +# Base class for inline _ModuleStackTracer.__init__.AttrProxy +class _AttrProxy: + def reset_proxy_mapping(self, base: Module, path: str) -> None: + pass + + +class _ModuleStackTracer(PythonKeyTracer): + r"""Customized version of PythonKeyTracer that retains module stack + information in node.meta["nn_module_stack"]. + + FX symbolic trace actually does this already, but it relies on `self.root` + being the actual module being traced. Since make_fx traces a lambda of our + creation, things don't work properly. + + So for this version we hold onto a reference to the original module + (scope_root) and use that to match the path. Also when we see, + A + / \ + B C + \ / + D + we want to record the path as A.B.D by recording only one path. + See Note [Preserving the nn module stack metadata during export non-strict mode] # noqa: W605 + """ + + def __init__(self, scope_root: GraphModule) -> None: + super().__init__() + self.stack_trace = True + self.scope_root = scope_root + self.enable_attr_proxy = False + self.submodule_paths = {} + for name, m in self.scope_root.named_modules(remove_duplicate=False): + if m in self.submodule_paths: + log.info( + "Shared module found between %s and %s, AttrProxy is enabled.", + self.submodule_paths[m], + name, + ) + self.enable_attr_proxy = True + else: + self.submodule_paths[m] = name + + self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary() + self.attr_proxy_map: WeakKeyDictionary[Module, _AttrProxy] = WeakKeyDictionary() + self.proxy_modules: WeakKeyDictionary[_AttrProxy, Module] = WeakKeyDictionary() + self.counter = 0 + + self.module_id_cache = defaultdict(list) + for name, mod in self.scope_root.named_modules(remove_duplicate=False): + self.module_id_cache[id(mod)].append(name) + + # Build a wrapper around _AttrProxy to provide the tracer. We can't + # store it on _AttrProxy itself beceause we mimic the underlying class + # (including its attributes). + tracer = self + + class AttrProxy(_AttrProxy): + def __init__(self, base: Union[Module, _AttrProxy], path: str) -> None: + if isinstance(base, _AttrProxy): + base = base.get_base() # type: ignore[attr-defined] + + assert isinstance(base, Module) + # Class is modified to be a subclass of torch.nn.Module + # Warning: We blow away our own attributes here to mimic the base class + # - so don't expect `self.x` to do anything useful. + self.__class__ = type( + base.__class__.__name__, + (self.__class__, base.__class__), + {}, + ) + self.__dict__ = base.__dict__ + self.__class__.__module__ = base.__class__.__module__ + self.__class__.__qualname__ = base.__class__.__qualname__ + + # This overwrites any existing paths if `base` is an AttrProxy + tracer.proxy_paths[self] = path + tracer.proxy_modules[self] = base + + def __getattr__(self, name: str) -> AttrProxy: + assert isinstance(self, Module) + # Calling into torch.nn.Module.__getattr__ with super(), + # That __getattr__ is patched to be module_getattr_wrapper in _symbolic_trace.py. + # which then calls into _ModuleStackTracer.getattr + attr_val = super().__getattr__(name) # type: ignore[misc] + if not isinstance(attr_val, Module): + return attr_val + + return AttrProxy(attr_val, tracer.proxy_paths[self] + "." + name) + + def get_base(self) -> Module: + return tracer.proxy_modules[self] + + def __getitem__(self, idx: Union[int, slice]) -> AttrProxy: + if isinstance(idx, slice): + if isinstance(self, torch.nn.Sequential): + # Copied from nn/modules/container.py + res = torch.nn.Sequential( + OrderedDict(list(self._modules.items())[idx]) + ) + return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}") + elif isinstance(self, torch.nn.ModuleList): + # Copied from nn/modules/container.py + res = torch.nn.ModuleList(list(self._modules.values())[idx]) + return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}") + + return super().__getitem__(idx) # type: ignore[misc] + + @property + def _modules(self) -> dict[str, AttrProxy]: + assert "_modules" in self.__dict__ + submodules = self.__dict__["_modules"] + assert isinstance(submodules, dict) + return { + key: ( + AttrProxy(value, tracer.proxy_paths[self] + "." + str(key)) # type: ignore[misc] + if value is not None + else value + ) + for key, value in submodules.items() + } + + self.proxy_type = AttrProxy + + def path_of_module(self, mod: Module) -> str: + """ + Use tracked access path during tracing instead of the default BFS behavior. + Still use all the possible module paths to verify the result. + """ + if mod is self.scope_root: + return "" + + if isinstance(mod, _AttrProxy): + return self.proxy_paths[mod] + + try: + return Tracer.path_of_module(self, mod) + except NameError as e: + raise _ModuleNotInstalledAsSubmoduleError from e + + def getattr( + self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy] + ) -> object: + if ( + not isinstance(attr_val, Module) + or isinstance(attr_val, fx.GraphModule) + or not self.enable_attr_proxy + ): + return super().getattr(attr, attr_val, parameter_proxy_cache) + if isinstance(attr_val, _AttrProxy): + return attr_val + + # See NOTE [caching AttrProxy]. + if attr_val not in self.attr_proxy_map: + self.attr_proxy_map[attr_val] = self.proxy_type(attr_val, attr) + else: + self.attr_proxy_map[attr_val].reset_proxy_mapping(attr_val, attr) + return self.attr_proxy_map[attr_val] + + def trace( # type: ignore[override] + self, root: Union[Module, Callable], concrete_args: Optional[dict[str, object]] + ) -> fx.Graph: + res = super().trace(root, concrete_args) + + # Since we are making _AttrProxy mimic the original + # submodule, when someone registers a module directly + # to the tracer while tracing, the proxy object gets registered + # first. So we need to replace the proxy modules with the real ones + # This can happen during HOO tracing + proxy_module_names_to_be_replaced: list[tuple[str, _AttrProxy]] = [] + for name, module in self.root.named_modules(): + if module in self.proxy_modules: + proxy_module_names_to_be_replaced.append((name, module)) + + def _delete_proxy_attr(obj: Module, target: str) -> bool: + # Copied from fx/graph_module.py + # Customized it for proxy type + atoms = target.split(".") + path, target_submod = atoms[:-1], atoms[-1] + assert isinstance(obj, Module) + mod = obj + + # Get the parent module + for item in path: + if not hasattr(mod, item): + return False + + mod = getattr(mod, item) + + if not isinstance(mod, (_AttrProxy, Module)): + return False + + if not hasattr(mod, target_submod): + return False + + # At least the leaf module should be proxy type. + if not isinstance(getattr(mod, target_submod), _AttrProxy): + return False + + delattr(mod, target_submod) + return True + + for proxy_module_name, proxy_module in proxy_module_names_to_be_replaced: + _delete_proxy_attr(self.root, proxy_module_name) + actual_module = self.proxy_modules[proxy_module] + _assign_attr(actual_module, self.root, proxy_module_name) + + return res + + def call_module( + self, + m: Module, + forward: Callable, + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> None: + """PythonKeyTracer overrides call_module to avoid the scope handling, + but we actually want it. + """ + from torch._dynamo import OptimizedModule + + # FIXME (tmanlaibaatar) + # When we call torch.compile inside HOO, we will end up + # invoking a module that is not registered on the root. For + # now, we just inline them. But once we start supporting + # mark_strict in export, we do need to properly handle this. + # Right now, it doesn't matter because current non-strict + # use cases don't need to work with HOO. + if isinstance(m, (OptimizedModule, GraphModule)): + return forward(*args, **kwargs) + + try: + return Tracer.call_module(self, m, forward, args, kwargs) + except _ModuleNotInstalledAsSubmoduleError: + log.debug( + "Unable to find the path of the module %s. " + "This might be because the module was not properly registered " + "as a submodule, which is not good practice. We will trace " + "through the module without recording stack information.", + str(m), + ) + return forward(*args, **kwargs) + + def is_leaf_module(self, m: Module, module_qualified_name: str) -> bool: + return False + + def create_node(self, *args: object, **kwargs: object) -> fx.node.Node: + """ + Create node and add on metadata. + Add nn_module_stack here instead of TracerBase, + since calls to make_fx() might not want to record module stack metadata. + Add torch_fn by looking at torch_fn_metadata and torch_fn_counts. + Add stack_trace by filtering out forward() stack frames. + """ + node = super().create_node(*args, **kwargs) # type: ignore[arg-type] + + # nn_module_stack + if node.op not in ["placeholder", "output"]: + if "nn_module_stack" not in node.meta: + node.meta["nn_module_stack"] = self.module_stack + # convert nn_module_stack from Dict[key, (FQN, class)] -> Dict[str, Tuple[str, str]] + for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items(): + if isinstance(mod_cls, type): + node.meta["nn_module_stack"][key] = ( + fqn, + mod_cls.__module__ + "." + mod_cls.__qualname__, + ) + + # torch_fn + if ( + node.op == "call_function" + and self.torch_fn_metadata is not None + and "torch_fn" not in node.meta + ): + node.meta["torch_fn"] = ( + f"{self.torch_fn_metadata.__name__}_{self.torch_fn_counts[self.torch_fn_metadata]}", + f"{self.torch_fn_metadata.__class__.__name__}.{self.torch_fn_metadata.__name__}", + ) + + return node + + +class _MakefxTracer: + def __init__( + self, + decomposition_table: Optional[Mapping[OpOverload, Callable]], + tracing_mode: str, + _allow_non_fake_inputs: bool, + pre_dispatch: bool, + record_module_stack: bool, + _allow_fake_constant: bool, + _error_on_data_dependent_ops: bool, + stack_trace: bool = False, + ) -> None: + # Configurations that are used to initialize the context managers and their states. + # Should not modify them during tracing. + self.decomposition_table: dict[OpOverload, Callable] = dict( + decomposition_table or {} + ) + self.decomposition_table.setdefault( + torch.ops.aten.sym_numel.default, torch._decomp.decompositions.sym_numel + ) + self.tracing_mode: str = tracing_mode + self._allow_non_fake_inputs: bool = _allow_non_fake_inputs + self.pre_dispatch: bool = pre_dispatch + self.record_module_stack: bool = record_module_stack + self._allow_fake_constant: bool = _allow_fake_constant + self._error_on_data_dependent_ops: bool = _error_on_data_dependent_ops + + # All context managers and their states should be initialized before tracing based on the inputs + # and configurations. After tracing, their states should be cleaned except for shape_env. + # Remember to specify how to initialize it from user inputs and from parent tracer whenever + # adding new modes in _MakefxTracer. + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext() + self.proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode] = ( + nullcontext() + ) + self.fx_tracer: Optional[PythonKeyTracer] = None + self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext() + self.torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode] = ( + nullcontext() + ) + self.stack_trace = stack_trace + + def _checkpoint_modes(self) -> list[Any]: + return [ + self.fake_tensor_mode, + self.proxy_mode, + self.proxy_function_mode, + self.fx_tracer, + self.python_dispatcher_mode, + self.torch_fn_metadata_mode, + ] + + def _restore_modes( + self, + prev_fake_tensor_mode: Optional[FakeTensorMode], + prev_proxy_mode: Union[nullcontext, ProxyTorchDispatchMode], + prev_proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode], + prev_fx_tracer: Optional[PythonKeyTracer], + prev_python_dispatcher_mode: Union[nullcontext, Any], + prev_torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode], + ) -> None: + self.fake_tensor_mode = prev_fake_tensor_mode + self.proxy_mode = prev_proxy_mode + self.proxy_function_mode = prev_proxy_function_mode + self.fx_tracer = prev_fx_tracer + self.python_dispatcher_mode = prev_python_dispatcher_mode + self.torch_fn_metadata_mode = prev_torch_fn_metadata_mode + + @contextmanager + def _init_modes_from_inputs( + self, f: Callable, args: tuple[object, ...] + ) -> Generator[None, None, None]: + prev_modes = self._checkpoint_modes() + try: + # Avoid importing sympy at a module level + from .symbolic_shapes import ShapeEnv + + if hasattr(f, "_orig_mod") and self.record_module_stack: + scope_root = f._orig_mod + # _ModuleStackTracer always try to preserve stack trace + self.fx_tracer = _ModuleStackTracer(scope_root) + else: + self.fx_tracer = PythonKeyTracer() + self.fx_tracer.stack_trace = self.stack_trace + + if self.tracing_mode == "fake": + import torch._dynamo + + fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args) + if fake_tensor_mode is None: + import torch._functorch.config as _config + + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=True, + allow_non_fake_inputs=self._allow_non_fake_inputs, + shape_env=ShapeEnv(), + static_shapes=True, + ) + self.fake_tensor_mode = fake_tensor_mode + elif self.tracing_mode == "symbolic": + import torch._dynamo + + fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args) + if fake_tensor_mode is None: + shape_env = ShapeEnv() + import torch._functorch.config as _config + + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=False, + allow_non_fake_inputs=self._allow_non_fake_inputs, + shape_env=shape_env, + ) + assert fake_tensor_mode.shape_env is not None, ( + "shape_env should be set if tracing with 'symbolic'" + ) + self.fake_tensor_mode = fake_tensor_mode + else: + if not self.tracing_mode == "real": + raise AssertionError( + f"Unexpected tracing type: {self.tracing_mode}" + ) + + self._construct_modes_with_fx_tracer(self.fx_tracer) + yield + finally: + self._restore_modes(*prev_modes) + + def _construct_modes_with_fx_tracer(self, fx_tracer: _ProxyTracer) -> None: + self.proxy_mode = ProxyTorchDispatchMode( + fx_tracer, + self.tracing_mode, + pre_dispatch=self.pre_dispatch, + _allow_fake_constant=self._allow_fake_constant, + _error_on_data_dependent_ops=self._error_on_data_dependent_ops, + ) + + if self.pre_dispatch: + self.proxy_function_mode = PreDispatchTorchFunctionMode(fx_tracer) + + # pre-autograd tracing uses per-dispatch-key modes, + # which requires the python dispatcher + if self.tracing_mode == "symbolic" or self.pre_dispatch: + self.python_dispatcher_mode = enable_python_dispatcher() + + self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer) + + @contextmanager + def _init_modes_from_parent( + self, parent_tracer: _MakefxTracer + ) -> Generator[None, None, None]: + # By default, subtracer creates new modes based on parent tracer's config. + # However, there are cases where we want to share the same modes with parent tracer + # For example, fake_tensor_mode, we want the example value's fake_mode of parent graph and subgraphs to be the same. + prev_modes = self._checkpoint_modes() + try: + self.fake_tensor_mode = parent_tracer.fake_tensor_mode + + def _create_sub_fx_tracer(parent_tracer: _ProxyTracer) -> PythonKeyTracer: + if type(parent_tracer) == PythonKeyTracer: + return PythonKeyTracer() + elif type(parent_tracer) == _ModuleStackTracer: + return _ModuleStackTracer(parent_tracer.scope_root) + else: + raise RuntimeError( + f"Unexpected tracer type: {type(parent_tracer)}." + ) + + assert parent_tracer.fx_tracer is not None + self.fx_tracer = _create_sub_fx_tracer(parent_tracer.fx_tracer) + self._construct_modes_with_fx_tracer(self.fx_tracer) + yield + finally: + self._restore_modes(*prev_modes) + + def _trace_inner(self, f: Callable, *args: object) -> GraphModule: + # TODO: We need to explicitly import torch._dynamo before calling dispatch_trace, + # because dispatch_trace will introduce the lazy import of torch._dynamo, + # and some contexts set before calling dispatch_trace will cause problems with the import of torch._dynamo, + # such as some torch API(torch.ones and so on) in populate_builtin_to_tensor_fn_map() will be affected + # by the context set before dispatch_trace. + import torch._dynamo + + phs = pytree.tree_map(lambda _: torch.fx._symbolic_trace.PH, args) + + def _wrap_fake(args: T) -> T: + arg_count = 0 + + def inner_wrap_fake(x: object) -> object: + nonlocal arg_count + # TODO: it would be nice to line these up with the names + # FX will choose for the placeholders, but we don't + # actually know what the names will be at this point yet + # NB: the Source here is actually meaningless + from torch._dynamo.source import ConstantSource + + assert self.fake_tensor_mode is not None + source = ConstantSource(f"input{arg_count}") + if isinstance(x, Tensor): + arg_count += 1 + return self.fake_tensor_mode.from_tensor(x, source=source) + # NB: don't match on bools + elif type(x) is int and self.tracing_mode == "symbolic": + assert self.fake_tensor_mode.shape_env is not None, ( + "shape_env should be set if tracing with 'symbolic'" + ) + return self.fake_tensor_mode.shape_env.create_symintnode( + self.fake_tensor_mode.shape_env.create_symbol( + x, source, positive=None + ), + hint=x, + source=source, + ) + elif isinstance(x, torch.ScriptObject): + return torch._library.fake_class_registry.maybe_to_fake_obj( + self.fake_tensor_mode, x + ) + + assert not isinstance(x, FakeScriptObject), ( + f"ScriptObject {x} has been fakified. Cannot wrap_fake it again." + ) + return x + + wrap_fn_map = { + "real": lambda x: x, + "fake": inner_wrap_fake, + "symbolic": inner_wrap_fake, + } + return pytree.tree_map(wrap_fn_map[self.tracing_mode], args) + + def _wrap_func(f: Callable[_P, R], phs: Sequence[PHBase]) -> Callable[_P, R]: + if ( + not hasattr(inspect.unwrap(f), "__code__") + or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS + ): + # FX doesn't support varargs, so we gotta fake up a wrapper + # TODO: Would be nice to fix this at the source... + return fake_signature(f, len(phs)) + return f + + args = _wrap_fake(args) + func = _wrap_func(f, phs) + # We disable the autocast cache as the autocast cache causes type conversions on parameters to + # check a cache, which introduces untracked tensors into the graph + # + # We also disable tracing by any other tensor proxy-based tracers except the current. The + # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is + # thus irrelevant to any external functional trace. + proxy_mode: ProxyTorchDispatchMode = typing.cast( + ProxyTorchDispatchMode, self.proxy_mode + ) + with ExitStack() as stack: + stack.enter_context(decompose(self.decomposition_table)) + if self.fake_tensor_mode: + stack.enter_context(self.fake_tensor_mode) + stack.enter_context(self.python_dispatcher_mode) + stack.enter_context(self.proxy_function_mode) + stack.enter_context(self.torch_fn_metadata_mode) + stack.enter_context(proxy_mode) + stack.enter_context(disable_autocast_cache()) + stack.enter_context(_set_make_fx_tracer(self)) + + assert self.fx_tracer is not None + try: + t = dispatch_trace( + wrap_key(func, args, self.fx_tracer, self.pre_dispatch), + tracer=self.fx_tracer, + concrete_args=tuple(phs), + ) + except Exception: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "make_fx_fail_partial", + "encoding": "string", + }, + payload_fn=lambda: self.fx_tracer.graph.python_code( # type: ignore[union-attr] + root_module="self", + verbose=True, + include_stride=True, + include_device=True, + ).src, + ) + raise + + # TODO: kind of a bad way to do it, should maybe figure out a better way + if self.tracing_mode == "symbolic": + assert self.fake_tensor_mode is not None + t.shape_env = self.fake_tensor_mode.shape_env # type: ignore[assignment] + return t + + def trace(self, f: Callable, *args: object) -> fx.GraphModule: + with self._init_modes_from_inputs(f, args): + return self._trace_inner(f, *args) + + def trace_subgraph(self, f: Callable, *args: object) -> GraphModule: + # Create a new tracer based on parent's config + sub_tracer = _MakefxTracer( + self.decomposition_table, + "real", + self._allow_non_fake_inputs, + self.pre_dispatch, + self.record_module_stack, + self._allow_fake_constant, + self._error_on_data_dependent_ops, + ) + with sub_tracer._init_modes_from_parent(self): + return sub_tracer._trace_inner(f, *args) + + +_CURRENT_MAKE_FX_TRACER: Optional[_MakefxTracer] = None + + +@contextmanager +def _set_make_fx_tracer(tracer: _MakefxTracer) -> Generator[None, None, None]: + global _CURRENT_MAKE_FX_TRACER + prev_tracer = _CURRENT_MAKE_FX_TRACER + try: + _CURRENT_MAKE_FX_TRACER = tracer + yield + finally: + _CURRENT_MAKE_FX_TRACER = prev_tracer + + +def make_fx( + f: Callable, + decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, + tracing_mode: str = "real", + _allow_non_fake_inputs: bool = False, + *, + pre_dispatch: bool = False, + record_module_stack: bool = False, + _allow_fake_constant: bool = False, + _error_on_data_dependent_ops: bool = True, + stack_trace: bool = False, +) -> Callable[..., GraphModule]: + """ + Given a function f, return a new function which when executed with valid + arguments to f, returns an FX GraphModule representing the set of operations that + were executed during the course of execution. + + If stack_trace is True, the stack_trace will be preserved on node.meta["stack_trace"] + """ + + assert tracing_mode in ["real", "fake", "symbolic"] + + from torch._inductor import config + + make_fx_tracer = _MakefxTracer( + decomposition_table, + tracing_mode, + _allow_non_fake_inputs, + pre_dispatch, + record_module_stack, + _allow_fake_constant, + _error_on_data_dependent_ops, + stack_trace=stack_trace or config.trace.enabled, + ) + + @functools.wraps(f) + def wrapped(*args: object) -> GraphModule: + return make_fx_tracer.trace(f, *args) + + return wrapped + + +def get_torch_dispatch_modes() -> list[TorchDispatchMode]: + return torch.utils._python_dispatch._get_current_dispatch_mode_stack() + + +# TODO: this is a legacy name, there is only ever one proxy mode as it's an +# infra mode +def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]: + return get_proxy_mode() + + +def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]: + """ + Current the currently active proxy tracing mode, or None if + we are not currently tracing. This includes pre-dispatch proxy + tracing. + """ + pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch( + torch._C._TorchDispatchModeKey.PROXY + ) + mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) + assert pre_dispatch_mode is None or mode is None, ( + f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}" + ) + return pre_dispatch_mode or mode + + +def handle_sym_dispatch( + func: Callable[_P, R], + args: _P.args, # type: ignore[valid-type] # not allowed to use _P.args here + kwargs: _P.kwargs, # type: ignore[valid-type] # not allowed to use _P.kwargs here +) -> R: + """ + Call into the currently active proxy tracing mode to do a + SymInt/SymFloat/SymBool dispatch trace on a function that operates on + these arguments. + """ + mode = get_proxy_mode() + assert mode + # Have to do it manually, because we're not doing the normal torch + # dispatch machinery which disables it for us + with disable_proxy_modes_tracing(): + # TODO: properly compute types + types: list[type] = [] + return mode.__sym_dispatch__(func, types, args, kwargs) # type: ignore[arg-type, return-value] + + +@contextmanager +def disable_proxy_modes_tracing() -> Generator[ProxyTorchDispatchMode, None, None]: + return _disable_infra_mode(torch._C._TorchDispatchModeKey.PROXY) + + +def maybe_handle_decomp( + proxy_mode: ProxyTorchDispatchMode, + op: OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + from torch._inductor.compiler_bisector import CompilerBisector + + if op in CURRENT_DECOMPOSITION_TABLE: + if CompilerBisector.disable_subsystem( + "aot_eager_decomp_partition", "decomposition", lambda: repr(op) + ): + return NotImplemented + + with proxy_mode: + proxy_mode.decomp_layers += 1 + out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs) + proxy_mode.decomp_layers -= 1 + return out + + return NotImplemented + + +def get_isolated_graphmodule( + func: Callable, + args: tuple[object, ...], + kwargs: dict[str, object], + tracing_mode: str = "real", + decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, +) -> GraphModule: + """A helper function used to get the GraphModule for the given func. + + It's expected to be used in the ProxyTensor tracing context. + It detaches the args and kwargs from the current tracer so that the trace of + the current graph module can be created without any side-effects. + """ + wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs) + + with disable_proxy_modes_tracing(): + gm = make_fx( + wrapped, decomposition_table=decomposition_table, tracing_mode=tracing_mode + )(all_args) + return gm + + +def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None: + """A helper function for setting up unbacked_bindings on the destination FX graph.""" + from .symbolic_shapes import compute_unbacked_bindings + + # Can't use detect_fake_mode here, + # + # python test/distributed/_tensor/test_dtensor_compile.py -k + # test_tp_compile_fullgraph_is_seq_parallel_False + # + # will fail. Very strange, it probably isn't right for them to be using + # two fake modes there... + fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) + if fake_mode and fake_mode.shape_env: + if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out): + assert isinstance(out_proxy, Proxy), out_proxy + out_proxy.node.meta["unbacked_bindings"] = symbol_to_path diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/recording.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/recording.py new file mode 100644 index 0000000000000000000000000000000000000000..a9025fc54ebe3e415653c11c37cad5a197cc8cb1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/recording.py @@ -0,0 +1,529 @@ +# mypy: allow-untyped-defs +import functools +import inspect +import itertools +import logging +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +import torch.utils._pytree as pytree + + +log = logging.getLogger(__name__) +trace_shape_events_log = torch._logging.getArtifactLogger( + __name__, "trace_shape_events" +) + + +__all__ = [ + "ShapeEnvEvent", + "record_shapeenv_event", + "replay_shape_env_events", + "FakeTensorMeta", + "shape_env_check_state_equal", + "NotEqualError", +] + +# [Note: Recording ShapeEnv Events] +# ================================= +# +# What is a ShapeEnv event? +# ------------------------- +# We consider a ShapeEnv event every function call (ShapeEnv method or +# independent function) that modifies the state of the ShapeEnv instance. +# Such calls are recorded alongside their positional and keyword arguments, +# so that it may be replayed over a different ShapeEnv instance. +# +# See [Note: ShapeEnv State Equality] for what is considered the state +# of a ShapeEnv instance. +# +# What is it for? +# --------------- +# ShapeEnv events recording is used for reconstructing the ShapeEnv in an +# arbitrary state in time. +# +# Being able to arbitrarily replay events like so is useful, mainly for +# translation validation bisection. i.e. if a ValidationException has been +# raised, find the earliest point in time where the translation validation +# fails. +# +# Besides that, it also allows us to inspect the given instance and, +# for example, check the guards that would actually be issued at that point. +# +# What kind of arguments can be stored in an event? +# ------------------------------------------------- +# There's no specific rule for what cannot be used as an argument. +# That said, pay special attention to the following cases: +# +# 1. Tensor inputs: there are some tests that check whether the inputs +# were garbage collected after execution. These will fail if there's +# an event that is holding a reference to those inputs. +# +# 2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that +# will be automatically replaced by the new given ShapeEnv instance. +# +# 3. SymTypes arguments: they also hold references to ShapeEnv. So, +# whenever we see them, we create a new instance, replacing the +# ShapeEnv reference. +# +# 4. FX nodes: specifically, FX nodes from the FX graph for symbolic +# shapes. That argument must be replaced when replaying the event at +# ShapeEnvEvent.run, since it has to reference a node from the given +# instance, and not from the recorded instance. + + +# Event class for reconstructing ShapeEnv at arbitrary time. +# +# Represents a method call that mutates ShapeEnv in a way that affects the +# issued guards, when ShapeEnv.produce_guards is called. +@dataclass +class ShapeEnvEvent: + # ShapeEnv method. + f: Callable + + # Arguments and keyword arguments called with. + args: Optional[list[Any]] = None + kwargs: Optional[dict[str, Any]] = None + + # List of tracked_fakes at the time the method was called. + tracked_fakes: Optional[list[Any]] = None + + # Name of the captured event. + # Used for special handling of particular methods. + name: Optional[str] = None + + # Replay itself, but using shape_env as self. + def run(self, shape_env=None) -> Any: + from torch.fx.experimental.symbolic_shapes import ( + is_symbolic, + ShapeEnv, + SymTypes, + ) + + # Special handling for the constructor event. + if self.f is ShapeEnv: + assert shape_env is None and self.args is None and self.kwargs is not None + return ShapeEnv(**self.kwargs) + + assert shape_env is not None + args = list(self.args or []) + kwargs = dict(self.kwargs or {}) + + # Replace any argument of type ShapeEnv by the given one. + args, kwargs = pytree.tree_map_only( + ShapeEnv, lambda _: shape_env, (args, kwargs) + ) + + # Replace any argument of type SymTypes by a new instance, + # replacing its ShapeEnv reference. + args, kwargs = pytree.tree_map_only( + lambda x: isinstance(x, SymTypes) and is_symbolic(x), + lambda a: type(a)(a.node.with_shape_env(shape_env)), + (args, kwargs), + ) + + # Converts FX nodes using the mapping argument. + def maybe_convert_node(x: Any) -> Any: + if not isinstance(x, torch.fx.Node): + # Don't do anything to x if it's not an FX node. + return x + + # If, at some point, we created an FX node, it means that translation validation is on. + # It also means we are building an FX graph for symbolic shapes at shape_env.graph, and + # we are tracking node names at shape_env.name_to_node. + assert hasattr(shape_env, "name_to_node") + name_to_node = shape_env.name_to_node # type: ignore[attr-defined] + assert x.name in name_to_node + return name_to_node[x.name] + + # Replaces the value of an specific argument by the result of fn. + def replacearg(index: int, key: str, fn: Callable): + if index < len(args): + args[index] = fn(args[index]) + if key in kwargs: + kwargs[key] = fn(kwargs[key]) + + if self.is_create_fx_call_function(): + # ShapeEnv.create_fx_call_function: + # "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv. + # They must be replaced, since a "call_function" FX node with this tuple as argument + # will be added to the FX graph of the new shape_env. + replacearg( + index=2, + key="args", + fn=lambda args: tuple(maybe_convert_node(a) for a in args), + ) + if self.is_evaluate_expr() or self.is_defer_runtime_assert(): + # ShapeEnv.evaluate_expr and ShapeEnv.guard_or_defer_runtime_assert: + # "fx_node" parameter is an (optional) FX node that represents the evaluate expression. + # They must be replaced, since it will be part of a "call_function" FX node for + # torch._assert, which will be added to the FX graph of the new shape_env. + replacearg(index=3, key="fx_node", fn=maybe_convert_node) + + # Actually call the method with the converted arguments. + return self.f(*args, **kwargs) + + def __str__(self) -> str: + name = self.name if self.name is not None else self.f.__name__ + return f"event: {name} ({self.args}, {self.kwargs})" + + def is_create_fx_call_function(self) -> bool: + return self.name == "_create_fx_call_function" + + def is_evaluate_expr(self) -> bool: + return self.name == "evaluate_expr" + + def is_defer_runtime_assert(self) -> bool: + return self.name == "guard_or_defer_runtime_assert" + + +NEST = 0 + + +# Extracts a ShapeEnv instance inside args and kwargs. +# Specifically, it looks for: +# 1. ShapeEnv arguments +# 2. SymInt, SymFloat, or SymBool arguments +# If we find more than one object of any of the above types, we +# also check that the ShapeEnv instance is the same for all of them. +def _extract_shape_env_and_assert_equal(args, kwargs): + from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes + + def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv: + if old is not None: + assert old is new, "call with different ShapeEnv" + return new + + shape_env = None + for val in itertools.chain(args, kwargs.values()): + if isinstance(val, ShapeEnv): + shape_env = assert_equal(shape_env, val) + if isinstance(val, SymTypes) and is_symbolic(val): + shape_env = assert_equal(shape_env, val.node.shape_env) + + return shape_env + + +# Decorator for recording the given function as a replayable event. +# +# This decorator should be used at every function that mutates the state of +# ShapeEnv in some way that affects the resulting issued guards (i.e. when +# ShapeEnv.produce_guards is called). +# +# save_tracked_fakes: saves a snapshot of the TrackedFake list. +# This is used when calling ShapeEnv.produce_guards at arbitrary points in time. +# +# name: the name of the function being recorded. Normally (and by default) this +# is taken from the decorated function but can be set if you need to override +# it. +# +# When to save the list of TrackedFake? +# ===================================== +# We should save the list of TrackedFake whenever the translation validation +# bisection may actually stop and call the produce_guards method at the moment +# right after the recorded function was played. In other words, since the +# bisection bisects through torch._assert calls, we should save in all methods +# that adds a torch._assert call to the symbolic shapes FX graph. +# +# At the moment, there are 2 methods that save the list: +# - ShapeEnv.evaluate_expr +# - ShapeEnv.guard_or_defer_runtime_assert +def record_shapeenv_event( + *, save_tracked_fakes: bool = False, name: Optional[str] = None +) -> Callable: + def decorator(fn: Callable) -> Callable: + assert callable(fn) + args = inspect.getfullargspec(fn).args + assert args and args[0] == "self", ( + "record_shapeenv_event should only wrap methods on ShapeEnv; refactor your " + "code so that it calls into a method on ShapeEnv" + ) + nonlocal name + if name is None: + name = fn.__name__ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + assert isinstance(args[0], ShapeEnv) + + global NEST + + trace_shape_events_log.debug( + "%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs + ) + NEST += 1 + + def retlog(r): + trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r) + return r + + shape_env = args[0] + + try: + if not shape_env.should_record_events or shape_env.is_recording: # type: ignore[has-type] + # If ShapeEnv is already recording an event, call the wrapped + # function directly. + # + # NB: here, we skip the check of whether all ShapeEnv instances + # are equal, in favor of a faster dispatch. + return retlog(fn(*args, **kwargs)) + + # Retrieve an instance of ShapeEnv. + # Assumption: the collection of args and kwargs may not reference + # different ShapeEnv instances. + self = _extract_shape_env_and_assert_equal(args, kwargs) + + # If we are calling this function without any ShapeEnv instance + # alive in its arguments, we don't record and call the original. + if self is None: + return retlog(fn(*args, **kwargs)) + + # Otherwise, start recording and call the function. + with self._recording(): + # Take a snapshot of the current tracked_fakes. + tracked_fakes = ( + self._snapshot_tracked_fakes() if save_tracked_fakes else None + ) + # Record the event for 'fn'. + event = ShapeEnvEvent( + fn, + list(args), + kwargs, + tracked_fakes, + name=name, + ) + # Play the event on this ShapeEnv. + # NB: It's important to put the event first, because running + # the event can trigger internal events that must be ordered + # after this event. However, if an exception happens, we do + # NOT want to have the event in the list, so pop it off from + # the record if an error happened + self.events.append(event) + try: + return retlog(event.run(self)) + except Exception: + self.events.pop() + raise + + except Exception: + if not shape_env.should_record_events or shape_env.is_recording: + # If ShapeEnv is disabled or already recording an event, re-raise the exception without logging. + raise + log.error( # noqa: G201 + "failed while running %s(*%s, **%s)", + name, + args[1:], + kwargs, + exc_info=log.isEnabledFor(logging.INFO), + ) + raise + + finally: + NEST -= 1 + + return wrapper + + return decorator + + +# Replays the ShapeEnvEvents list. +# It assumes the first event is the constructor call. +# +# fn: transforms an old FX node into one corresponding to the newly created ShapeEnv. +def replay_shape_env_events(events): + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + constructor_event = events[0] + assert constructor_event.f == ShapeEnv + + # Constructs the new ShapeEnv. + shape_env = constructor_event.run() + + for event in events[1:]: + try: + # Actually replays each event. + # We need to call create_mapping_fn every time, since the node list might + # change after each event is replayed. + event.run(shape_env) + except Exception: + log.error("failed when running event: %s", event) + raise + + return shape_env + + +# FakeTensor metadata. +# This is to be used in place of FakeTensor placeholders when calling +# ShapeEnv.produce_guards. +@dataclass +class FakeTensorMeta: + tensor_size: tuple[Union[int, torch.SymInt], ...] + tensor_stride: tuple[Union[int, torch.SymInt], ...] + tensor_storage_offset: Union[int, torch.SymInt] + is_nested: bool + + def size(self) -> tuple[Union[int, torch.SymInt], ...]: + return self.tensor_size + + def stride(self) -> tuple[Union[int, torch.SymInt], ...]: + return self.tensor_stride + + def storage_offset(self) -> Union[int, torch.SymInt]: + return self.tensor_storage_offset + + def dim(self) -> int: + return len(self.tensor_size) + + @staticmethod + def from_fake(fake) -> "FakeTensorMeta": + return FakeTensorMeta( + fake.size(), fake.stride(), fake.storage_offset(), fake.is_nested + ) + + +# [Note: ShapeEnv State Equality] +# =============================== +# +# What is considered ShapeEnv state? +# ---------------------------------- +# We consider to be the state of a ShapeEnv instance everything that +# is not in the inline tuple inside remove_nonstate_variables function. +# That is: the fields within ShapeEnv that modify the flow of execution +# of the program. +# +# So, for example: the replacements field might influence on how an +# expression is simplified. That, in turn, may result in a guard being +# statically known (i.e. not added). +# +# On the other hand, var_to_stack serves only changes what is printed +# in the screen, i.e. used only for debugging purposes. Therefore, we +# should not consider it when comparing states. +# +# What to do on NotEqualError? +# ---------------------------- +# Here are a few possible causes for getting a NotEqualError raised: +# +# 1. New field that does not belong in the ShapeEnv state. +# For example: log field of type ShapeEnvLoggerAdapter. Different +# ShapeEnv instances will always have different ShapeEnvLoggerAdapter +# instances, i.e. equality comparison would fail. +# Solution: add it to the inlined tuple inside remove_nonstate_variables +# function inside check_equal method. +# +# 2. New field that is not directly comparable across instances. +# For example: guards field of type List[ShapeGuard]. More specifically, +# the ShapeGuard type holds an expression and a stack information +# for debugging purposes. When replaying the even on a new ShapeEnv +# instance, the stack would be different, which would trigger this error. +# Solution: add a special case to the map_value function inside +# check_equal function. +# +# 3. Mutation of ShapeEnv on some not recorded function. +# If a mutation of the state of ShapeEnv happens inside a function +# that is not recorded (or that no caller in the stack is recorded), +# then, the replayed ShapeEnv won't catch that. +# Solution: decorate the function with record_shape_env_event. + + +# Checks whether the state of two ShapeEnv are equal w.r.t. the guards +# returned by ShapeEnv.produce_guards. +def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value): + # Collect and remove variables that don't necessarily represent the state + # of a ShapeEnv. Note: we copy the dictionary so that we don't modify the + # instance itself. + env1_vars = vars(env1).copy() + env2_vars = vars(env2).copy() + + for v in non_state_variable_names: + if v in env1_vars: + env1_vars.pop(v) + if v in env2_vars: + env2_vars.pop(v) + + # Function for transforming the mismatched values into string. + # Needed, since dict and set entries order might not be the same every time. + def value_to_str(value: Any) -> str: + if isinstance(value, dict): + return ( + "{" + + ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str)) + + "}" + ) + if isinstance(value, set): + return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}" + return str(value) + + # Compares env1_vars with env2_vars. + # Here, we allow the value of each field to be mapped, so that we appropriately + # compare the two values. + def compare_vars( + map_value: Callable[[str, Any], Any], + ) -> list[tuple[str, str, str]]: + env1_set, env2_set = set(env1_vars), set(env2_vars) + + # First, compare the set of keys in each vars dictionary. + if env1_set != env2_set: + raise NotEqualError( + "field set mismatch:", + [ + ( + "found unique fields:", + str(sorted(env1_set - env2_set)), + str(sorted(env2_set - env1_set)), + ), + ], + ) + + # Then, sort the keys, and compare the mapped values of each key. + sorted_keys = list(env1_set) + sorted_keys.sort() + + mapped_dict = [ + (k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k])) + for k in sorted_keys + ] + + # Return a list of tuples representing the fields that did not match + # alongside their respective mapped values. + return [ + (f"{k}: values don't match.", value_to_str(val1), value_to_str(val2)) + for k, val1, val2 in mapped_dict + if val1 != val2 + ] + + # Accumulate the mismatching fields. + errors = compare_vars(map_value) + + if len(errors) > 0: + raise NotEqualError("field values don't match:", errors) + + +class NotEqualError(Exception): + def __init__( + self, + msg: str, + mismatched: list[tuple[str, str, str]], + ) -> None: + details = "\n".join( + [ + "\n".join( + [ + f"==> {inner_msg}", + f" > Left: {str1}", + f" > Right: {str2}", + ] + ) + for inner_msg, str1, str2 in mismatched + ] + ) + + super().__init__( + f"""\ +ShapeEnv not equal: {msg} + +{details} +""" + ) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/refinement_types.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/refinement_types.py new file mode 100644 index 0000000000000000000000000000000000000000..8e92163a2139caab2fd2a690d810f52073e75644 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/refinement_types.py @@ -0,0 +1,16 @@ +class Equality: + def __init__(self, lhs: object, rhs: object): + self.lhs = lhs + self.rhs = rhs + + def __str__(self) -> str: + return f"{self.lhs} = {self.rhs}" + + def __repr__(self) -> str: + return f"{self.lhs} = {self.rhs}" + + def __eq__(self, other: object) -> bool: + if isinstance(other, Equality): + return self.lhs == other.lhs and self.rhs == other.rhs + else: + return False diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/rewriter.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/rewriter.py new file mode 100644 index 0000000000000000000000000000000000000000..8e635a525f6f09c2759c8d3fa105068f70ac6094 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/rewriter.py @@ -0,0 +1,143 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import ast +import copy +import functools +import inspect +import textwrap +from types import FunctionType +from typing import Any, Callable, cast, Optional, Union + +import torch +from torch._sources import normalize_source_lines +from torch.fx._symbolic_trace import Tracer +from torch.fx.graph import Graph + + +class AST_Rewriter(ast.NodeTransformer): + """ + Take a FunctionType object representing a `forward` method, then + perform an AST rewrite to swap out nodes that are not symbolically + traceable with a callsite to the FX alternative. + + To support swapping out an AST node, define a new `visit` method on + that node. For more details, see: + https://docs.python.org/3/library/ast.html#ast.NodeTransformer + """ + + # This function checks for new keys added in the globals dict. TorchDynamo + # can insert new keys in the global dict and upset the check. Therefore, put + # a disable here. This function is an optimization pass and not really + # suitable for dynamo tracing anyways. + @torch._dynamo.disable + def rewrite(self, fn: FunctionType): + # Normalize the source lines + sourcelines, _ = inspect.getsourcelines(fn) + sourcelines = normalize_source_lines(sourcelines) + source = "".join(sourcelines) + normalized_str = textwrap.dedent(source) + + # Rewrite the original AST + source_ast = ast.parse(normalized_str) + dest_ast = ast.fix_missing_locations(self.visit(source_ast)) + + # Pull out the compiled function from the newly-created Module + code = compile(dest_ast, "", "exec") + globals_dict = copy.copy(fn.__globals__) + keys_before = set(globals_dict.keys()) + exec(code, globals_dict) + new_keys = list(set(globals_dict.keys()) - keys_before) + assert len(new_keys) == 1 + fn_compiled = globals_dict[new_keys[0]] + + # return the compiled function with the original globals + def change_func_globals(f, globals): + """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)""" + # __globals__ is a private member of the function class + # so we have to copy the function, f, all of its member, except f.__globals__ + g = FunctionType( + f.__code__, + globals, + name=f.__name__, + argdefs=f.__defaults__, + closure=f.__closure__, + ) + g = functools.update_wrapper(g, f) + g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined] + return g + + # Return the correct FunctionType object + return change_func_globals(fn_compiled, globals=fn.__globals__) + + def visit_Assert(self, node): + """ + Swap out the Assert node (Python's `assert`) with a callsite to the + symbolically-traceable torch._assert function + """ + # Create the Call node + n = ast.parse("torch._assert()", mode="eval") + assert isinstance(n, ast.Expression) + call_node = n.body + assert isinstance(call_node, ast.Call) + msg = node.msg if node.msg else ast.Constant(value="", kind=None) + call_node.args = [node.test, msg] + + # Ensure that the new node conforms to the Python AST grammar + expr_wrapper = ast.Expr(value=call_node) + + # Return the new Call node to signify that we want to use it as + # a replacement for the original _assert node + return ast.copy_location(expr_wrapper, node) + + def visit_AnnAssign(self, node): + """ + Swap out Python's AnnAssign with an Assign node where the annotation function is called. + Example: + Original: + y: Tensor_Type(1,2,3, Dyn) = f2(x) + Output: + y = annotate(f2(x),Tensor_Type((1,2,3,Dyn))) + """ + return ast.Assign( + targets=[node.target], + value=ast.Call( + func=ast.Name(id="annotate", ctx=ast.Load()), + args=[node.value, node.annotation], + keywords=[], + ), + ) + + +class RewritingTracer(Tracer): + def trace( + self, + root: Union[torch.nn.Module, Callable], + concrete_args: Optional[dict[str, Any]] = None, + ) -> Graph: + return super().trace(_rewrite(root), concrete_args) + + +def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]: + if isinstance(fn, torch.nn.Module): + # Rewrite this module's `forward` as well as the `forward`s of + # all of this module's recursive descendents. Return the new, + # rewritten module hierarchy. + def rewrite_module(m: torch.nn.Module): + class RewrittenModule(torch.nn.Module): + def __init__(self, orig): + super().__init__() + for k, v in orig.__dict__.items(): + if isinstance(v, torch.nn.Module): + self.__dict__[k] = copy.copy(rewrite_module(v)) + else: + self.__dict__[k] = copy.copy(v) + + RewrittenModule.forward = AST_Rewriter().rewrite( + cast(FunctionType, m.forward) + ) + return RewrittenModule(m) + + return rewrite_module(fn) + else: + # Rewrite this single free function + return AST_Rewriter().rewrite(cast(FunctionType, fn)) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/schema_type_annotation.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/schema_type_annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..b1b2f1680d64a1ff928a8519dd4d93d61a861a54 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/schema_type_annotation.py @@ -0,0 +1,145 @@ +# mypy: allow-untyped-defs +import inspect +from typing import Any, Optional + +import torch +import torch.fx +from torch._jit_internal import boolean_dispatched +from torch.fx import Transformer +from torch.fx.node import Argument, Target +from torch.fx.operator_schemas import _torchscript_type_to_python_type + + +class AnnotateTypesWithSchema(Transformer): + """ + Use Python function signatures to annotate types for `Nodes` within an FX graph. + This pulls out Python function signatures for: + + 1. Standard `torch.nn` Module calls + 2. `torch.nn.functional` calls + 3. Attribute fetches via `get_attr` + + Example usage: + + m = torchvision.models.resnet18() + + traced = torch.fx.symbolic_trace(m) + + traced = AnnotateTypesWithSchema(traced).transform() + + """ + + def __init__( + self, + module: torch.nn.Module, + annotate_functionals: bool = True, + annotate_modules: bool = True, + annotate_get_attrs: bool = True, + ): + super().__init__(module) + self.annotate_functionals = annotate_functionals + self.annotate_modules = annotate_modules + self.annotate_get_attrs = annotate_get_attrs + + def call_function( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ): + python_ret_type = None + if self.annotate_functionals and target.__module__ == "torch.nn.functional": + target_for_analysis = target + if target in boolean_dispatched: + # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have + # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` + # branches of the dispatch have exactly the same signature. If they do, use the `true` + # branch signature for analysis. Otherwise, leave this un-normalized + assert not isinstance(target, str) + dispatched = boolean_dispatched[target] + if_true, if_false = dispatched["if_true"], dispatched["if_false"] + # TODO: can we emit the union of these? What are the implications on TorchScript + # compilation? + if ( + inspect.signature(if_true).return_annotation + != inspect.signature(if_false).return_annotation + ): + return super().call_function(target, args, kwargs) + target_for_analysis = if_true + + python_ret_type = self._extract_python_return_type(target_for_analysis) + + return_proxy = super().call_function(target, args, kwargs) + return_proxy.node.type = ( + return_proxy.node.type if return_proxy.node.type else python_ret_type + ) + return return_proxy + + def call_module( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ): + python_ret_type = None + assert isinstance(target, str) + submod = self.fetch_attr(target) + if self.annotate_modules and hasattr(submod.__class__, "__name__"): + classname = submod.__class__.__name__ + if getattr(torch.nn, classname, None) == submod.__class__: + python_ret_type = self._extract_python_return_type(submod.forward) + return_proxy = super().call_module(target, args, kwargs) + return_proxy.node.type = ( + return_proxy.node.type if return_proxy.node.type else python_ret_type + ) + return return_proxy + + def get_attr( + self, + target: torch.fx.node.Target, + args: tuple[Argument, ...], + kwargs: dict[str, Any], + ): + attr_proxy = super().get_attr(target, args, kwargs) + + if self.annotate_get_attrs: + module_itr = self.module + assert isinstance(target, str) + atoms = target.split(".") + for i, atom in enumerate(atoms): + if not hasattr(module_itr, atom): + raise RuntimeError( + f"Node referenced nonextent target {'.'.join(atoms[:i])}!" + ) + module_itr = getattr(module_itr, atom) + + maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr) + if maybe_inferred_ts_type.success(): + python_type = _torchscript_type_to_python_type( + maybe_inferred_ts_type.type() + ) + attr_proxy.node.type = ( + python_type if not attr_proxy.node.type else attr_proxy.node.type + ) + + return attr_proxy + + def _extract_python_return_type(self, target: Target) -> Optional[Any]: + """ + Given a Python call target, try to extract the Python return annotation + if it is available, otherwise return None + + Args: + + target (Callable): Python callable to get return annotation for + + Returns: + + Optional[Any]: Return annotation from the `target`, or None if it was + not available. + """ + assert callable(target) + try: + sig = inspect.signature(target) + except (ValueError, TypeError): + return None + + return ( + sig.return_annotation + if sig.return_annotation is not inspect.Signature.empty + else None + ) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py new file mode 100644 index 0000000000000000000000000000000000000000..5468191163ab73a9ca4f7a97b75829daf8e00d9d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py @@ -0,0 +1,1847 @@ +# mypy: allow-untyped-defs + +from __future__ import annotations + + +""" +This file does three things: +- Contains the definition of SymNode +- Installs all the magic methods into SymBool, SymFloat, SymFloat at import time +- Does not depend on sympy at import time + +As this file is imported from within torch/__init__.py we do not want it to depend on SymPy +to avoid having to load SymPy at import time, as doing so is *very* slow. +""" + + +import builtins +import functools +import inspect +import itertools +import logging +import math +import operator +import sys +from functools import lru_cache, update_wrapper +from typing import Optional, TYPE_CHECKING, Union + +import torch +import torch._logging.structured as structured + +# NB: The sym_* functions are used via getattr() and must be imported here. +from torch import ( # noqa: F401 + sym_float, + sym_ite, + sym_max, + sym_min, + sym_not, + SymBool, + SymFloat, + SymInt, +) +from torch._logging import dtrace_structured + + +if TYPE_CHECKING: + from torch.fx.experimental.symbolic_shapes import ShapeEnv + +log = logging.getLogger(__name__) +sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node") + + +__all__ = ["SymNode", "method_to_operator", "magic_methods"] + + +from torch.types import py_sym_types as SymTypes + + +def _to_symtype(t): + if t is bool: + return SymBool + if t is int: + return SymInt + if t is float: + return SymFloat + return t + + +# TODO: An incomplete list +# 1. Set variables to be equal when we do equality +# 2. Specialize on 0/1 when we do subtraction +class SymNode: + """ + This is a type erased SymInt/SymFloat which we use to do actual operations. + End users don't touch this. Magic methods are NOT defined on this object. + """ + + # Note [optimized_summation]: indicates that SymNode is an Add expression of the form + # a + b + c + d... etc where all terms are unique symbols. This allows us to do some optimizations + # for common patterns see _optimized_add. + + # The unfortunate reason we have this here is because sympy sets __slots__ = () for add expression, + # so we cannot add the attribute directly to the sympy expression. Furthermore, we cannot use it as + # a weak dictionary key either! So instead, we attach the attribute here to the SymNode. + _optimized_summation: bool = False + + def __init__( + self, + expr, + shape_env, + pytype, + hint: Optional[Union[int, float, bool]], + constant=None, + fx_node=None, + optimized_summation=False, + ): + self._expr = expr + self.shape_env = shape_env + self.pytype = pytype + self._optimized_summation = optimized_summation + + # What's the difference between hint and constant? + # + # - A constant is known to be invariant across invocations of the model; + # it will always be this value. We only really know this when we + # encounter an honest-to-goodness literal (when wrapping it into + # a SymNode, we set constant.) Most of the time, constant is None + # + # - A hint is a *particular* value from the particular run we are + # tracing, but it may vary the next time around. It's useful to + # keep this around, as if we need a concrete value from a SymNode, + # we will return the hint and guard on the expression that produced + # it giving the same hint next time around. The hint is not + # guaranteed to be set either: if you have an unbacked SymNode, + # there won't be any hint; it was the result of some tensor-dependent + # computation, but we don't know what it actually is because we + # haven't actually run the tensor computation. + # + # If _hint is None, we will query maybe_evaluate_static(compute_hint=True) + # in hopes that we've learned enough about the unbacked symints to + # discharge the hint; otherwise, you're likely to just error out. + # + # (A previous version of this system had some optimizations to only + # recompute when it was possible we had learned enough about the + # unbacked symint that a hint was now possible, but as we added more + # potential refinements to unbacked symints this got harder to keep + # in sync, so we've deleted it for now.) + + def compute_hint(): + from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols + + # This occasionally gets exercised by, e.g., + # convert_shape_to_symint. It's just a nicety so you don't HAVE + # to have a correct hint on hand when making a SymNode. + # Don't attempt to compute for unbacked, this can be quite + # expensive. + if has_free_unbacked_symbols(self.expr): + return None + hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True) + if hint is not None: + hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint + return hint + + if hint is not None: + assert type(hint) is pytype or type(hint) is _to_symtype(pytype), ( + "Cannot create SymNode of type " + f"{pytype} with incompatible hint of type {type(hint)}" + ) + if self.shape_env and self.shape_env._translation_validation_enabled: + # This is technically not TV, but this assert is expensive so + # let's only do it when we're already doing expensive things + computed_hint = compute_hint() + assert hint == computed_hint, ( + f"{hint} != {computed_hint} (for {self.expr})" + ) + else: + hint = compute_hint() + self._hint = hint + self.constant: Optional[Union[int, float, bool]] = constant + + # Record the FX node of the current node if we are doing translation + # validation. They will be used for building the input assertions for + # the translation validation problem. + tx_validation_en = ( + self.shape_env and self.shape_env._translation_validation_enabled + ) + self.fx_node = tx_validation_en and fx_node + + def with_shape_env(self, shape_env: ShapeEnv) -> SymNode: + return SymNode( + self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node + ) + + def _value_eq(self, other: SymNode) -> bool: + # Purposely don't include the shape_env in the eq. + return ( + self._expr == other._expr + and self.pytype == other.pytype + and self._hint == other._hint + and self.constant == other.constant + and self.fx_node == other.fx_node + ) + + def _value_hash(self) -> int: + # Purposely don't include the shape_env in the hash. + return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node)) + + @property + def expr(self): + return self.shape_env.replace(self._expr) + + @property + def hint(self): + return self._hint + + def has_hint(self): + return self._hint is not None + + def require_hint(self, fallback=None): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + if self._hint is None: + if fallback is not None: + # Say we have some expr like 2*u0 + s0 + # The hint will be None, since the expr contains at least 1 unbacked. + # We will: + # - replace every backed free symbol with its corresponding hint + # - replace every unbacked free symbol with the fallback + # - regenerate the expression with those symbol replacements + # Note: this is not really complete either, since right now + # this logic does not take into account any value ranges + # for the unbacked symints, we may need to beef it up at some point. + unbacked_symbols = free_unbacked_symbols(self.expr) + replacements = { + s: 4096 if s in unbacked_symbols else self.shape_env.var_to_val[s] + for s in self.expr.free_symbols + } + return self.expr.xreplace(replacements) + # NB: we expect this to raise + return self.shape_env.size_hint(self.expr) + return self._hint + + def maybe_as_int(self): + if self.expr.is_number: + return int(self.expr) + else: + return None + + # NB: This does conversions, not sure if this is good or not + def maybe_as_float(self): + import sympy + + if isinstance(self.expr, sympy.Float): + return float(self.expr) + else: + return None + + def maybe_as_bool(self): + import sympy + + if self.expr is sympy.true: + return True + elif self.expr is sympy.false: + return False + else: + return None + + def is_int(self): + return self.pytype is int + + def is_float(self): + return self.pytype is float + + def is_bool(self): + return self.pytype is bool + + def is_nested_int(self): + # Unbacked SymInts cannot be nested int today + return ( + self._hint is not None + and isinstance(self._hint, SymInt) + and self._hint.node.is_nested_int() + ) + + def wrap_int(self, num): + assert type(num) is int + import sympy + + return SymNode( + sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num + ) + + def wrap_float(self, num): + assert type(num) is float + import sympy + + return SymNode( + sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num + ) + + def wrap_bool(self, num): + assert type(num) is bool + import sympy + + return SymNode( + sympy.true if num else sympy.false, + self.shape_env, + bool, + num, + constant=num, + fx_node=num, + ) + + def clone(self): + return self + + def str(self): + return f"{self.expr}" + + def __str__(self): + return self.str() + + def __repr__(self): + rep = [ + f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}", + ] + if self._hint is not None: + rep.append(f"hint={self._hint}") + if self.constant is not None: + rep.append(f"constant={self.constant}") + if self.fx_node is not None: + rep.append(f"fx_node={self.fx_node}") + return ", ".join(rep) + ")" + + def _graph_repr(self) -> builtins.str: + # Representation used by GraphModule to create a pythonic version of a graph + return self.str() + + # These methods call the metaprogrammed methods, they're hand written + # here so we get good stack traces + def abs(self) -> SymNode: + return self._abs() # type: ignore[attr-defined] + + def pos(self) -> SymNode: + return self._pos() # type: ignore[attr-defined] + + def round(self, ndigits=None) -> SymNode: + return self._round(ndigits) # type: ignore[attr-defined] + + def trunc(self) -> SymNode: + return self._trunc() # type: ignore[attr-defined] + + def add(self, other) -> SymNode: + return self._add(other) # type: ignore[attr-defined] + + def sub(self, other) -> SymNode: + return self._sub(other) # type: ignore[attr-defined] + + def mul(self, other) -> SymNode: + return self._mul(other) # type: ignore[attr-defined] + + def mod(self, other) -> SymNode: + return self._mod(other) # type: ignore[attr-defined] + + def float_pow(self, other) -> SymNode: + return self._float_pow(other) # type: ignore[attr-defined] + + def pow_by_natural(self, other) -> SymNode: + return self._pow_by_natural(other) # type: ignore[attr-defined] + + def and_(self, other) -> SymNode: + return self._and_(other) # type: ignore[attr-defined] + + def or_(self, other) -> SymNode: + return self._or_(other) # type: ignore[attr-defined] + + def float_truediv(self, other) -> SymNode: + return self._float_truediv(other) # type: ignore[attr-defined] + + def int_truediv(self, other) -> SymNode: + return self._int_truediv(other) # type: ignore[attr-defined] + + def int_floordiv(self, other) -> SymNode: + return self._int_floordiv(other) # type: ignore[attr-defined] + + def lshift(self, other) -> SymNode: + return self._lshift(other) # type: ignore[attr-defined] + + def rshift(self, other) -> SymNode: + return self._rshift(other) # type: ignore[attr-defined] + + def sym_not(self) -> SymNode: # noqa: F811 + return self._sym_not() # type: ignore[attr-defined] + + def eq(self, other) -> SymNode: + return self._eq(other) # type: ignore[attr-defined] + + def ne(self, other) -> SymNode: + return self._ne(other) # type: ignore[attr-defined] + + def gt(self, other) -> SymNode: + return self._gt(other) # type: ignore[attr-defined] + + def lt(self, other) -> SymNode: + return self._lt(other) # type: ignore[attr-defined] + + def le(self, other) -> SymNode: + return self._le(other) # type: ignore[attr-defined] + + def ge(self, other) -> SymNode: + return self._ge(other) # type: ignore[attr-defined] + + def floor(self) -> SymNode: + return self._floor() # type: ignore[attr-defined] + + def is_integer(self) -> SymNode: + return self._is_integer() # type: ignore[attr-defined] + + def sym_float(self) -> SymNode: # noqa: F811 + return self._sym_float() # type: ignore[attr-defined] + + def sym_int(self) -> SymNode: + return self._sym_int() # type: ignore[attr-defined] + + def ceil(self) -> SymNode: + return self._ceil() # type: ignore[attr-defined] + + def neg(self) -> SymNode: + return self._neg() # type: ignore[attr-defined] + + def sym_min(self, other) -> SymNode: # noqa: F811 + return self._sym_min(other) # type: ignore[attr-defined] + + def sym_max(self, other) -> SymNode: # noqa: F811 + return self._sym_max(other) # type: ignore[attr-defined] + + def sym_ite(self, then_val, else_val) -> SymNode: + return self._sym_ite(then_val, else_val) # type: ignore[attr-defined] + + def is_contiguous(self, sizes, strides) -> SymNode: + return self._is_contiguous(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_contiguous_2d(self, sizes, strides) -> SymNode: + return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_contiguous_3d(self, sizes, strides) -> SymNode: + return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_strides_2d(self, sizes, strides) -> SymNode: + return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_strides_3d(self, sizes, strides) -> SymNode: + return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined] + + def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> SymNode: + return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined] + + # Make C++ happy + def sym_or(self, other): + return self.or_(other) + + def sym_and(self, other): + return self.and_(other) + + # Integer bitwise ops + def bitwise_and(self, other): + return self._bitwise_and(other) # type: ignore[attr-defined] + + def bitwise_or(self, other): + return self._bitwise_or(other) # type: ignore[attr-defined] + + # There is no int_truediv available from C++ + def truediv(self, other): + return self.float_truediv(other) + + def floordiv(self, other) -> SymNode: + return self.int_floordiv(other) + + # We didn't bind integer pow in C++ + def pow(self, other): + return self.float_pow(other) + + def is_non_overlapping_and_dense(self, sizes, strides): + return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq( + to_node(self, 1) + ) # type: ignore[attr-defined] + + def int_(self): + return self.guard_int("", 0) # NB: uses Python backtrace + + # This one is currently done by hand, but if we add other variadic + # functions consider factoring it out to be metaprogrammed too. Note that + # some load bearing logic is directly in torch.sym_sum + + def sym_sum(self, args) -> SymNode: + import sympy + + # Inner impl + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + if get_proxy_mode(): + return to_node( + self, + handle_sym_dispatch( + torch.sym_sum, + (tuple(wrap_node(a) for a in args),), + {}, + ), + ) + exprs = [a.expr for a in args] + out = sympy.Add(*exprs) + + size_hints = [] + out_hint = None + for a in args: + if a.hint is None: + break + size_hints.append(a.hint) + else: + out_hint = sum(size_hints) + + fx_node, _ = self.shape_env._create_fx_call_function( + torch.sym_sum, (tuple(a.fx_node for a in args),) + ) + + # NB: Only for integers! + return SymNode(out, self.shape_env, int, out_hint, fx_node=fx_node) + + def evaluate(self, size_oblivious=False): + return self.shape_env.evaluate_sym_node(self, size_oblivious) + + # You can manually trigger a guard with this function + def guard_int(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.evaluate() + try: + return int(r) + except Exception: + log.warning("Failed to convert to int: %s", r) + raise + + def guard_float(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.evaluate() + try: + return float(r) + except Exception: + log.warning("Failed to convert to float: %s", r) + raise + + def guard_bool(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.evaluate() + try: + return bool(r) + except Exception: + log.warning("Failed to convert to bool: %s", r) + raise + + def expect_true(self, file, line): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + if ( + self.has_hint() + and not free_unbacked_symbols(self.expr) + and not self.shape_env.prefer_deferred_runtime_asserts_over_guards + ): + # OK to generate guards + return self.guard_bool(file, line) + # Generate a deferred runtime assert (this might actually end up doing + # a regular guard if we can!) + # TODO: file/line here is very important, because the assert has been + # deferred so you can't backtrace easily + return self.shape_env.guard_or_defer_runtime_assert( + self.expr, f"{file}:{line}", fx_node=self.fx_node + ) + + def expect_size(self, file, line): + from torch.fx.experimental.symbolic_shapes import _advise_is_size + + b = self.ge(self.wrap_int(0)) + # Generate a deferred runtime assert + r = b.expect_true(file, line) + # Refine compile time range, but only if it's unbacked. + # If you refine range for hinted variables, you can end up making + # improper deductions since compile time reasoning may be + # incompatible with runtime reasoning. + if r and not self.has_hint(): + _advise_is_size(SymInt(self)) + return r + + def statically_known_true(self, file, line): + from torch.fx.experimental.symbolic_shapes import statically_known_true + + assert self.is_bool() + return statically_known_true(SymBool(self)) + + def guard_size_oblivious(self, file, line): + """ + Like guard_bool, but if we encounter unbacked symbols, if those symbols + are size-like, we will treat them as >= 2 for the purposes of the analysis. + + This CHANGES the runtime semantics, but all size-oblivious sites have been + audited to ensure that the runtime semantics don't change in a material way. + Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping + an unbacked one size, or a tensor reporting as non-contiguous even if it's + contiguous if it would have been reported contiguous due to being empty. + """ + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.evaluate(size_oblivious=True) + try: + return bool(r) + except Exception: + log.warning("Failed to convert to bool: %s", r) + raise + + def guard_or_false(self, file, line): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + assert self.is_bool() + return guard_or_false(SymBool(self)) + + def guard_or_true(self, file, line): + from torch.fx.experimental.symbolic_shapes import guard_or_true + + assert self.is_bool() + return guard_or_true(SymBool(self)) + + def bool_(self): + return self.guard_bool("", 0) + + def is_symbolic(self): + return True + + def nested_int(self): + return None + + def is_constant(self): + return False + + +# TODO: this probably needs the sizes-strides eval functions +METHOD_TO_OPERATOR = { + "pos": operator.pos, + "abs": operator.abs, + "add": operator.add, + "and": operator.and_, + "bitwise_and": operator.and_, + "ceil": math.ceil, + "eq": operator.eq, + "floor": math.floor, + "trunc": math.trunc, + "int_floordiv": operator.floordiv, + "ge": operator.ge, + "gt": operator.gt, + "is_integer": lambda x: x.is_integer(), + "le": operator.le, + "lshift": operator.lshift, + "lt": operator.lt, + "mod": operator.mod, + "mul": operator.mul, + "ne": operator.ne, + "neg": operator.neg, + "or": operator.or_, + "bitwise_or": operator.or_, + "float_pow": operator.pow, + "pow_by_natural": operator.pow, + "round": builtins.round, + "rshift": operator.rshift, + "sub": operator.sub, + "sym_float": sym_float, + "sym_ite": sym_ite, + "sym_max": sym_max, + "sym_min": sym_min, + "sym_not": sym_not, + "float_truediv": operator.truediv, + "int_truediv": operator.truediv, +} + +unary_magic_methods = { + "abs", + "sym_float", + "sym_int", + "ceil", + "floor", + "neg", + "sym_not", + "pos", + "trunc", +} + + +# Adding math ops: sqrt, cos, sin, ... +def _get_sym_node_fn(name): + def fn(self): + return getattr(self, f"_sym_{name}")() + + return fn + + +math_op_names = ( + "sqrt", + "cos", + "cosh", + "sin", + "sinh", + "tan", + "tanh", + "asin", + "acos", + "atan", + "log2", +) +for name in math_op_names: + sym_name = f"sym_{name}" + priv_sym_name = f"_{sym_name}" + setattr(SymNode, sym_name, _get_sym_node_fn(name)) + METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name) + unary_magic_methods.add(sym_name) + __all__.append(sym_name) + + +# Unary methods that are not magic methods +unary_nonmagic_methods = { + "is_integer", +} + +unary_methods = unary_magic_methods | unary_nonmagic_methods + +# Most methods are only registered on SymInt and SymFloat +# Some methods are only be registered on SymBool +only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"} +# Methods that implicitly convert SymBool into SymInt +bool_becomes_int_magic_methods = {"add", "sub", "mul"} +# Methods that are also on SymBool, in addition to on SymInt and SymFloat +also_bool_magic_methods = {"eq"} +bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods + +# Methods that are only for float +only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"} + + +magic_methods_on_operator_with_trailing_underscore = {"and", "or"} +# remap necessary because an op name can have a bitwise and boolean implementation +bitwise_ops = { + "bitwise_and": "and", + "bitwise_or": "or", +} + + +always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} + +for name in math_op_names: + sym_name = f"sym_{name}" + always_float_magic_methods.add(sym_name) + + +always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"} +always_bool_magic_methods = { + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + "and", + "or", + "sym_not", + "is_non_overlapping_and_dense", + "is_integer", +} + +# Methods that have a `__foo__` as well as `__rfoo__` + + +def _sympy_float_truediv(a, b): + from torch.utils._sympy.functions import FloatTrueDiv + + return FloatTrueDiv(a, b) + + +def _sympy_int_truediv(a, b): + from torch.utils._sympy.functions import IntTrueDiv + + return IntTrueDiv(a, b) + + +def _sympy_floordiv(a, b): + from torch.utils._sympy.functions import FloorDiv + + return FloorDiv(a, b) + + +def _sympy_mod(a, b): + from torch.utils._sympy.functions import Mod, PythonMod + + if a.is_nonnegative and b.is_nonnegative: + return Mod(a, b) + else: + return PythonMod(a, b) + + +def _sympy_pow_by_natural(a, b): + from torch.utils._sympy.functions import PowByNatural + + return PowByNatural(a, b) + + +def _sympy_float_pow(a, b): + from torch.utils._sympy.functions import FloatPow + + return FloatPow(a, b) + + +def _sympy_and(a, b): + import sympy + + return sympy.And(a, b) + + +def _sympy_or(a, b): + import sympy + + return sympy.Or(a, b) + + +def _sympy_lshift(a, b): + from torch.utils._sympy.functions import LShift + + return LShift(a, b) + + +def _sympy_rshift(a, b): + from torch.utils._sympy.functions import RShift + + return RShift(a, b) + + +def _binary_search_insert_arg(ordered_args, new_arg): + """ + If new_arg is found in ordered_args None is returned, else the new + ordered_args with new_arg inserted + """ + if len(ordered_args) == 0: + return [new_arg] + + from sympy.core.basic import _args_sortkey as sort_key, Basic + + # Fast path when new_arg > ordered_args[-1]. + if sort_key(ordered_args[-1]) < sort_key(new_arg): + return ordered_args + [new_arg] + + # Fast path when new_arg < ordered_args[0]. + if sort_key(ordered_args[0]) > sort_key(new_arg): + return [new_arg] + ordered_args + + low, high = 0, len(ordered_args) - 1 + + while low <= high: + mid = (low + high) // 2 + compare_result = Basic.compare(ordered_args[mid], new_arg) + if compare_result == 0: + return None + elif compare_result < 0: + low = mid + 1 + else: + high = mid - 1 + + ordered_args.insert(low, new_arg) + return ordered_args + + +def _optimized_add( + lhs, rhs, lhs_is_optimized_summation=False, rhs_is_optimized_summation=False +): + """ + Custom optimization for Add used to optimize incremental binary summations of certain properties. The idea + is when we know the expression is a summation of unique symbols all we need to know is the correct order of symbols, + and no other optimizations are needed. We pass evaluate=false, with the correct order of args and save the following. + 1. Avoid running other optimizations when the Add is constructed. + 2. Manually figure out the order of the args for the new expression in log(n) comparisons instead of nLog(n) + (comparing terms is expensive and shows in the profiles). + The function returns a tuple of (1) a boolean that indicates whether the output is a summation of unique symbols, + (2) the result sympy expression. + """ + import sympy + from sympy.core.basic import _args_sortkey as sortkey + + def make_optimized(ordered_args): + assert ordered_args is not None + result = sympy.Add(*ordered_args, evaluate=False) + return (True, result) + + from torch.utils._sympy.functions import _is_symbols_binary_summation + + lhs_is_optimized_summation |= _is_symbols_binary_summation(lhs) + rhs_is_optimized_summation |= _is_symbols_binary_summation(rhs) + + if lhs_is_optimized_summation and rhs_is_optimized_summation: + # (a0+a1..) + (a2+a3..) => (a0+a1+a2+a3) + if sortkey(lhs._args[-1]) < sortkey(rhs._args[0]): + return make_optimized(lhs._args + rhs._args) + # (a2+a3..) + (a0+a1..) => (a0+a1+a2+a3) + if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]): + return make_optimized(rhs._args + lhs._args) + + # (a1+a3) + (a0+a2) => (a0+a1+a2+a3) + if len(lhs._args) <= 2 and len(rhs._args) <= 2: + new_args = list(lhs._args) + for a in rhs._args: + new_args = _binary_search_insert_arg(new_args, a) + if new_args is None: + break + # None means an element already exists. + if new_args is not None: + return make_optimized(new_args) + + # (a0+a2) + a1 => (a0+a1+a2) + if lhs_is_optimized_summation and rhs.is_symbol: + new_args = _binary_search_insert_arg(list(lhs._args), rhs) + # None means an element already exists. + if new_args is not None: + return make_optimized(new_args) + + # a1 + (a0+a2)=> (a0+a1+a2) + if rhs_is_optimized_summation and lhs.is_symbol: + new_args = _binary_search_insert_arg(list(rhs._args), lhs) + # None means an element already exists. + if new_args is not None: + return make_optimized(new_args) + + result = sympy.Add(lhs, rhs) + return (_is_symbols_binary_summation(result), result) + + +def _bitwise_and(a, b): + from torch.utils._sympy.functions import BitwiseFn_bitwise_and + + return BitwiseFn_bitwise_and(a, b) + + +def _bitwise_or(a, b): + from torch.utils._sympy.functions import BitwiseFn_bitwise_or + + return BitwiseFn_bitwise_or(a, b) + + +reflectable_magic_methods = { + "add": _optimized_add, + "sub": operator.sub, + "mul": operator.mul, + "mod": _sympy_mod, + "pow_by_natural": _sympy_pow_by_natural, + "float_pow": _sympy_float_pow, + "and": _sympy_and, + "bitwise_and": _bitwise_and, + "or": _sympy_or, + "bitwise_or": _bitwise_or, + "float_truediv": _sympy_float_truediv, + "int_truediv": _sympy_int_truediv, + "int_floordiv": _sympy_floordiv, + "lshift": _sympy_lshift, + "rshift": _sympy_rshift, +} + + +def _floor_ceil_helper(a, fn): + import sympy + + if isinstance(a, sympy.Mul): + aa = a.args + if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer: + coef = sympy.Integer(aa[0]) + if aa[0] == coef: # structural equality test + return coef * aa[1] + if ( + isinstance(a, sympy.Float) + and a == sympy.Integer(a) + or isinstance(a, sympy.Integer) + ): + return sympy.Integer(a) + return fn(a) + + +def _sympy_floor(a): + from torch.utils._sympy.functions import FloorToInt + + return FloorToInt(a) + + +# NB: this is Python trunc semantics which returns an int. Do NOT use this to +# represent torch.trunc (which is float to float) +def _sympy_trunc(a): + from torch.utils._sympy.functions import TruncToInt + + return TruncToInt(a) + + +def _sympy_ceil(a): + from torch.utils._sympy.functions import CeilToInt + + return CeilToInt(a) + + +def _sympy_eq(a, b): + import sympy + + return sympy.Eq(a, b) + + +def _sympy_ne(a, b): + import sympy + + return sympy.Ne(a, b) + + +def _sympy_gt(a, b): + import sympy + + return sympy.Gt(a, b) + + +def _sympy_lt(a, b): + import sympy + + return sympy.Lt(a, b) + + +def _sympy_le(a, b): + import sympy + + return sympy.Le(a, b) + + +def _sympy_ge(a, b): + import sympy + + return sympy.Ge(a, b) + + +def _sympy_min(a, b): + from torch.utils._sympy.functions import Min + + return Min(a, b) + + +def _sympy_max(a, b): + from torch.utils._sympy.functions import Max + + return Max(a, b) + + +def _sympy_ite(a, t, f): + import sympy + + return sympy.Piecewise((t, a), (f, True)) + + +current_module = sys.modules[__name__] + + +def _get_sym_math_fn(name): + def fn(a): + import torch.utils._sympy.functions + + return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a) + + return fn + + +for name in math_op_names: + priv_sympy_name = f"_sympy_{name}" + fn = _get_sym_math_fn(name) + fn.__qualname__ = fn.__name__ = priv_sympy_name + setattr(current_module, priv_sympy_name, fn) + +del fn, name, priv_sympy_name # type: ignore[possibly-undefined] + + +def _sympy_abs(a): + import sympy + + return sympy.Abs(a) + + +def _sympy_round(number, ndigits=None): + from torch.utils._sympy.functions import RoundDecimal, RoundToInt + + if ndigits is None: + return RoundToInt(number) + else: + return RoundDecimal(number, ndigits) + + +def _sympy_sym_float(a): + from torch.utils._sympy.functions import ToFloat + + # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly + # reports that it is an integer + return ToFloat(a) + + +def _sympy_is_integer(a): + import sympy + + from torch.utils._sympy.functions import ToFloat + + return sympy.Eq(ToFloat(sympy.floor(a)), a) + + +magic_methods = { + **reflectable_magic_methods, + "sym_not": operator.invert, + "pos": operator.pos, + "eq": _sympy_eq, + "ne": _sympy_ne, + "gt": _sympy_gt, + "lt": _sympy_lt, + "le": _sympy_le, + "ge": _sympy_ge, + "floor": _sympy_floor, + "trunc": _sympy_trunc, + "sym_float": _sympy_sym_float, + "ceil": _sympy_ceil, + "neg": operator.neg, + "sym_min": _sympy_min, + "sym_max": _sympy_max, + "sym_ite": _sympy_ite, + "abs": _sympy_abs, + "round": _sympy_round, + "is_integer": _sympy_is_integer, +} + + +for name in math_op_names: + sym_name = f"sym_{name}" + magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}") + +del name, sym_name, math_op_names, current_module # type: ignore[possibly-undefined] + + +def sympy_is_contiguous(sizes, strides): + dim = len(sizes) + return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1))) + + +def sympy_is_contiguous_generic(sizes, strides, dim_order): + import sympy + + dim = len(sizes) + + if len(dim_order) != dim: + return sympy.false + + is_contiguous = sympy.true + z = sympy.S.One + # Contiguous if the strides make sense (or the dim is size 1) + for d in dim_order: + is_contiguous &= sympy.Eq(sizes[d], sympy.S.One) | sympy.Eq(strides[d], z) + z *= sizes[d] + # OR if any size is zero + for d in range(dim): + is_contiguous |= sympy.Eq(sizes[d], sympy.S.Zero) + return is_contiguous + + +# NB: There is a TODO in C++ to allow omitting the batch dim. If that +# happens you will need to refactor this + + +def sympy_is_channels_last_contiguous_2d(sizes, strides): + return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0]) + + +def sympy_is_channels_last_contiguous_3d(sizes, strides): + return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0]) + + +def sympy_is_channels_last_strides_generic(sizes, strides, dim_order): + import sympy + + from torch.utils._sympy.functions import Max + + dim = len(sizes) + + if dim != len(dim_order): + return sympy.false + + m = sympy.S.Zero + r = sympy.true + + # special case for trivial C dimension. default to NCHW + r &= sympy.Ne(strides[1], 0) + + for d in dim_order: + r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m) + # Fallback to NCHW as default layout for ambiguous cases + # This is the flaw of implicit memory_format from strides. + # N111 tensor with identical strides for size 1 dimension; + # Two cases could lead us here: + # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) + # b. N11W contiguous Tensor sliced on the W-dimension. + # ([N,1,1,1]@[W,W,W,W]) + if d == 0: + r &= sympy.Ne(m, strides[1]) + # This is necessary to: + # 1. distinguish the memory_format of N1H1; + # [H, 1, 1, 1] channels_last stride + # [H, H, 1, 1] contiguous stride + # 2. permutation of 1C1W: + # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) + # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as + # channels_last + m = strides[d] * Max(sizes[d], 1) + + return r + + +def sympy_is_channels_last_strides_2d(sizes, strides): + return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0]) + + +def sympy_is_channels_last_strides_3d(sizes, strides): + return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0]) + + +def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides): + from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator + + return IsNonOverlappingAndDenseIndicator(*sizes, *strides) + + +sizes_strides_methods = { + # TODO: These could also be done with indicators, maybe it is better + # for reasoning to do it that way + "is_contiguous": sympy_is_contiguous, + "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d, + "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d, + "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d, + "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d, + "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator, +} + + +def to_node(self, num): + if isinstance(num, SymTypes): + return num.node + elif type(num) is bool: + return self.wrap_bool(num) + elif type(num) is int: + return self.wrap_int(num) + elif type(num) is float: + return self.wrap_float(num) + else: + # NotImplemented is important so that Python tries the + # other magic method + return NotImplemented + + +def wrap_node(x): + # TODO: let C++ also take advantage of this + if isinstance(x, SymNode) and x.constant is not None: + return x.constant + if x.is_int(): + return SymInt(x) + elif x.is_float(): + return SymFloat(x) + elif x.is_bool(): + return SymBool(x) + else: + raise AssertionError(f"unrecognized return type {x}") + + +def method_to_operator(method): + return METHOD_TO_OPERATOR[method] + + +def _make_node_magic(method, func): + func = lru_cache(256)(func) + + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"{method}_" + else: + method_attr = method + + def uninteresting_files() -> set[str]: + import torch + + mods = [ + torch._dynamo.eval_frame, + torch._dynamo.utils, + torch.fx.experimental.sym_node, + torch, + ] + import torch._dynamo.guards + + return ( + {inspect.getfile(m) for m in mods} + | torch._dynamo.guards.uninteresting_files() + | {""} + ) + + def capture_provenance(fn): + @functools.wraps(fn) + def wrapper(self, other=None): + if other is None: + result = fn(self) + else: + result = fn(self, other) + if torch._logging._internal.GET_DTRACE_STRUCTURED: + if other is not None: + arguments = [self, other] + else: + arguments = [self] + + def get_id(sym_node) -> Optional[int]: + # We don't want to return an ID if the input is a constant + import sympy + + if sym_node.constant is not None: + return None + elif id(sym_node) == id(result): + return None + elif isinstance(sym_node.expr, (sympy.Integer, sympy.Float)): + return None + elif sym_node.expr in (sympy.true, sympy.false): + return None + return id(sym_node) + + dtrace_structured( + "expression_created", + metadata_fn=lambda: { + "method": method, + "result": str(result), + "result_id": id(result), + "arguments": [str(a) for a in arguments], + "argument_ids": [ + get_id(i) for i in arguments if get_id(i) is not None + ], + "user_stack": structured.get_user_stack(3), + "stack": structured.get_framework_stack(3), + }, + ) + + return result + + return wrapper + + @capture_provenance + def binary_magic_impl(self, other): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + op = method_to_operator(method) + + out_hint = None + if self.hint is not None and other.hint is not None: + out_hint = op(self.hint, other.hint) + + if get_proxy_mode(): + return to_node( + self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) + ) + assert isinstance(other, SymNode) + optimized_summation = False + try: + if method == "mod": + from torch.utils._sympy.functions import Mod, PythonMod + + # Special handling for mod that requires access to the value + # ranges + shape_env = self.shape_env + if ( + self.expr.is_nonnegative + or shape_env.bound_sympy(self.expr).lower >= 0 + ) and ( + other.expr.is_nonnegative + or shape_env.bound_sympy(other.expr).lower >= 0 + ): + out = Mod(self.expr, other.expr) + else: + out = PythonMod(self.expr, other.expr) + elif method == "add": + # see Note [optimized_summation] + (optimized_summation, out) = func( + self.expr, + other.expr, + self._optimized_summation, + other._optimized_summation, + ) + else: + # TODO: consider constant prop here + out = func(self.expr, other.expr) + except Exception: + log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) + raise + sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out) + pytype: type + # This is not strictly correct. In Python, a**b may return complex when + # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This + # returns a float while both arguments are ints: 2**(-1). Also, max and + # min do not type promote. To avoid having data-dependent control flow + # here, we just set the type to float if one of the args is a float. In + # case of a type mismatch, we assume that it will be detected during + # evaluation. + if method in always_float_magic_methods: + pytype = float + elif method in always_bool_magic_methods: + pytype = bool + elif self.pytype is float or other.pytype is float: + pytype = float + else: + pytype = self.pytype + + if ( + pytype is not None + and out_hint is not None + and not isinstance(out_hint, SymTypes) + ): + out_hint = pytype(out_hint) + + # Create a FX node that corresponds to the operation being applied to + # this node. + fx_node, _ = self.shape_env._create_fx_call_function( + op, (self.fx_node, other.fx_node) + ) + + result = SymNode( + out, + self.shape_env, + pytype, + out_hint, # type: ignore[arg-type] + fx_node=fx_node, + optimized_summation=optimized_summation, # see Note [optimized_summation] + ) + return result + + @capture_provenance + def unary_magic_impl(self): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + op = method_to_operator(method) + if get_proxy_mode(): + return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {})) + # TODO: consider constant prop here + expr = self.expr + if method == "floor" or method == "ceiling": + expr = self.shape_env._simplify_floor_div(expr) + + try: + out = func(expr) + except Exception: + log.warning("failed to eval %s(%s)", method, expr) + raise + sym_node_log.debug("%s %s -> %s", func, expr, out) + out_hint = None + if self.hint is not None: + out_hint = op(self.hint) + pytype: type + if method in always_int_magic_methods: + pytype = int + elif method in always_bool_magic_methods: + pytype = bool + elif method in always_float_magic_methods: + pytype = float + else: + pytype = self.pytype + + fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,)) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + if method in unary_methods: + setattr(SymNode, f"_{method_attr}", unary_magic_impl) + elif method == "sym_ite": + + def sym_ite_impl(pred_node, then_node, else_node): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + out_hint = then_node.hint if pred_node.hint else else_node.hint + if get_proxy_mode(): + return to_node( + pred_node, + handle_sym_dispatch( + sym_ite, + ( + wrap_node(pred_node), + wrap_node(then_node), + wrap_node(else_node), + ), + {}, + ), + ) + + try: + out = func(pred_node.expr, then_node.expr, else_node.expr) + except Exception: + log.warning( + "failed to eval %s(%s, %s, %s)", + method, + pred_node.expr, + then_node.expr, + else_node.expr, + ) + raise + + fx_node, _ = pred_node.shape_env._create_fx_call_function( + sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node) + ) + return SymNode( + out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node + ) + + setattr(SymNode, f"_{method_attr}", sym_ite_impl) + elif method == "round": + + def round_impl(self, ndigits=None): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + op = builtins.round + if get_proxy_mode(): + return to_node( + self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {}) + ) + + expr = self.expr + try: + out = func(expr, ndigits) + except Exception: + log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) + raise + + if ndigits is None: + pytype = int + else: + pytype = self.pytype + + out_hint = None + if self.hint is not None: + out_hint = op(self.hint, ndigits) + + # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the + # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here + # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The + # hack down below works, because all round function down the line all take ndigits=None as default in their + # signature. + # TODO: Remove the args construction below if a different sentinel is used by FX. + # ezyang(May 2024): LOL + args = [self.fx_node] + if ndigits is not None: + args.append(ndigits) + fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args)) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + setattr(SymNode, f"_{method_attr}", round_impl) + else: + setattr(SymNode, f"_{method_attr}", binary_magic_impl) + + +def _make_node_sizes_strides(method, func): + # NB: don't LRU cache, lots of arguments + + def sizes_strides_impl(self, sizes, strides): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + op = getattr(sys.modules[__name__], method) + if get_proxy_mode(): + return to_node( + self, + handle_sym_dispatch( + op, + ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]), + {}, + ), + ) + size_exprs = [s.expr for s in sizes] + stride_exprs = [s.expr for s in strides] + try: + out = func(size_exprs, stride_exprs) + except Exception: + log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs) + raise + # bool is never expandable + + size_hints = [] + out_hint = None + for s in sizes: + if s.hint is None: + break + size_hints.append(s.hint) + else: + stride_hints = [] + for s in strides: + if s.hint is None: + break + stride_hints.append(s.hint) + else: + out_hint = op(size_hints, stride_hints) + + # NB: This is the indicator function, not the actual bool! + pytype: type + if method.endswith("_indicator"): + pytype = int + else: + pytype = bool + return SymNode(out, self.shape_env, pytype, out_hint) + + setattr(SymNode, f"_{method}", sizes_strides_impl) + + # TODO: This is technically hotpath, but in the ideal end state + # guards on this will resolve at a higher level so you never + # spend time in this code + def sizes_strides_user(sizes, strides): + import sympy + + from torch.fx.experimental.symbolic_shapes import ( + eval_is_non_overlapping_and_dense, + ) + + for a in itertools.chain(sizes, strides): + if isinstance(a, SymInt): + return wrap_node( + getattr(a.node, method)( + [to_node(a.node, b) for b in sizes], + [to_node(a.node, b) for b in strides], + ) + ) + if method == "is_non_overlapping_and_dense_indicator": + return eval_is_non_overlapping_and_dense(sizes, strides) + else: + # TODO: this is an awful implementation + return bool( + func( + [sympy.sympify(a) for a in sizes], + [sympy.sympify(a) for a in strides], + ) + ) + + # Skip for is_non_overlapping_and_dense_indicator + if not hasattr(sys.modules[__name__], method): + setattr(sys.modules[__name__], method, sizes_strides_user) + + +for method, func in magic_methods.items(): + _make_node_magic(method, func) + +for method, func in sizes_strides_methods.items(): + _make_node_sizes_strides(method, func) + + +def _make_user_magic(method, user_type): + # User magic takes care of wrapping the other operand into a node, + # so that our internal logic can assume everything is nodes + + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"sym_{method}" + else: + method_attr = method + + def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]): + if isinstance(x, (int, float, bool)): + return x + if isinstance(x, SymBool): + return x.node.guard_bool("", 0) + raise AssertionError("expect to be called with constant SymBools") + + def is_constant(x): + if isinstance(x, (int, float, bool)): + return True + if isinstance(x, (SymInt, SymFloat, SymBool)): + return x.node.is_constant() + return False + + # Promotion rules for binary operations. NB: we preserve PYTHON semantics + # - if args are same type, do nothing + # - if one arg is float, promote other arg to float + # - nb: this applies to floordiv, even though output is integral + # (it's still float) + # - pow is funny business + # - if both ints + # - trigger a guard on exponent >= 0 + # - if non-negative, output is int + # - otherwise, output is float + # - otherwise, promote other arg to float + # - nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + # - equality is pain: Python does the fancy thing where it unpacks the + # mantissa from the float and then compares that against the int. + # Which means it is able to tell that + # 9007199254740993 != 9007199254740992. (rather than if the LHS was + # promoted to float, in which case it would have truncated to the RHS + # and subsequently been equal). We'll model this exactly by having + # special mixed type equality operations. Unfortunately, we need to + # do this for all comparison operations (maybe I'll only implement + # compare) + # - sym_ite mumble mumble really shouldn't allow mixed but whatever + + if method in bool_becomes_int_magic_methods: + + def promote(x): + """Implements True+True=2, which works in python but not sympy""" + if isinstance(x, SymBool): + return SymInt(x.node.wrap_int(int(x))) + return x + + else: + + def promote(x): + return x + + def promote2(self, other): + # TODO: Remove eq and other relations from this list. + # CPython has fancy implementations for these to get as much precision + # as possible instead of just promoting to float64 and praying, so we + # need to handle them specially too. + # Also, note that int_truediv doesn't go through this path: both + # arguments are "int" so there isn't any promotion + if method not in [ + "add", + "sub", + "mul", + "mod", + "float_pow", + "float_truediv", + "int_floordiv", + "sym_min", + "sym_max", + # TODO: remove these + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + ]: + return self, other + f_self = isinstance(self, (float, torch.SymFloat)) + f_other = isinstance(other, (float, torch.SymFloat)) + if f_self or f_other: + if not f_self: + self = torch.sym_float(self) + if not f_other: + other = torch.sym_float(other) + return self, other + + # Before and after performing the operation, check if any operands are constant. + # If so, extract out the constant values first. If `self` itself is a + # constant, then "redispatch" by calling back into the operator. Sometimes + # this means that operations involving SymBool return plain bools. + # Alternatively, we could also rewrap into constant Symbool (i.e. by + # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that + # today for no particular reason. + def unary_magic_impl(self): + self = promote(self) + if is_constant(self): + return (method_to_operator(method))(get_constant(self)) + return wrap_node(getattr(self.node, method_attr)()) + + def binary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented + sym_node_log.debug("MAGIC %s %s %s", method, self, other) + self = promote(self) + other = promote(other) + self, other = promote2(self, other) + if is_constant(self): + return (method_to_operator(method))(get_constant(self), other) + if is_constant(other): + other = get_constant(other) + other_node = to_node(self.node, other) + if other_node is NotImplemented: + return NotImplemented + ret = wrap_node(getattr(self.node, method_attr)(other_node)) + return get_constant(ret) if is_constant(ret) else ret + + def rbinary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented + self = promote(self) + other = promote(other) + self, other = promote2(self, other) + if is_constant(self): + return (method_to_operator(method))(get_constant(self), other) + if is_constant(other): + other = get_constant(other) + other_node = to_node(self.node, other) + if other_node is NotImplemented: + return NotImplemented + ret = wrap_node(getattr(other_node, method_attr)(self.node)) + return get_constant(ret) if is_constant(ret) else ret + + if method in unary_magic_methods: + setattr(user_type, f"__{method}__", unary_magic_impl) + elif method in unary_nonmagic_methods: + orig = getattr(user_type, method) + setattr(user_type, method, update_wrapper(unary_magic_impl, orig)) + elif method == "sym_ite": + + def sym_ite_magic_impl(pred, then_val, else_val): + pred_node = pred.node + then_node = to_node(pred_node, then_val) + else_node = to_node(pred_node, else_val) + if then_node is NotImplemented or else_node is NotImplemented: + return NotImplemented + assert ( + isinstance(then_node, SymNode) + and isinstance(else_node, SymNode) + and then_node.pytype == else_node.pytype + ) + ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node)) + return get_constant(ret) if ret.node.is_constant() else ret + + setattr(user_type, f"__{method}__", sym_ite_magic_impl) + elif method == "round": + + def round_magic_impl(self, ndigits=None): + if is_constant(self): + return builtins.round(get_constant(self), ndigits) + + return wrap_node(getattr(self.node, method)(ndigits)) + + setattr(user_type, f"__{method}__", round_magic_impl) + else: + method_name = method + if method in bitwise_ops: + method_name = bitwise_ops[method] + setattr(user_type, f"__{method_name}__", binary_magic_impl) + if method in reflectable_magic_methods: + setattr(user_type, f"__r{method_name}__", rbinary_magic_impl) + + +for method, func in magic_methods.items(): # type: ignore[assignment] + if method in only_bool_magic_methods: + _make_user_magic(method, SymBool) + continue + if method in only_float_magic_methods: + _make_user_magic(method, SymFloat) + continue + if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods: + _make_user_magic(method, SymBool) + _make_user_magic(method, SymInt) + if method not in bitwise_ops: + _make_user_magic(method, SymFloat) + +del method +del func diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..10e17ab9f3e94749e82f8288c0486d1ea68c2040 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py @@ -0,0 +1,8055 @@ +from __future__ import annotations + +import sympy +from sympy import S + +from torch._prims_common import BoolLike, FloatLike, IntLike + + +""" +``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with +our symbolic shapes reasoning system that is used heavily in torch.compile. Although +this is not generally considered public API, when writing framework code in PyTorch +as well as extensions to PyTorch (e.g., in custom operator implementations), you may +need to make use of these APIs to setup dynamic shapes support appropriately. +""" + +import abc +import atexit +import collections +import dis +import functools +import hashlib +import inspect +import itertools +import logging +import math +import operator +import os +import re +import sys +import threading +import traceback +from collections import Counter, defaultdict +from collections.abc import Generator, Iterator, Mapping, Sequence +from contextlib import _GeneratorContextManager, contextmanager +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import ( + Any, + Callable, + cast, + Generic, + NamedTuple, + NoReturn, + Optional, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import deprecated, ParamSpec, TypeAlias, TypeGuard + +import torch +import torch.fx +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree + +# NB: The sym_* functions are used via getattr() and must be imported here. +from torch import SymBool, SymFloat, SymInt +from torch._guards import ShapeGuard, SLoc, Source, TracingContext +from torch._logging import dtrace_structured, LazyString, structured, trace_structured +from torch._subclasses.meta_utils import is_sparse_any +from torch._utils_internal import signpost_event +from torch.fx.experimental import _config as config +from torch.fx.experimental.recording import ( + FakeTensorMeta, + record_shapeenv_event, + replay_shape_env_events, + shape_env_check_state_equal, + ShapeEnvEvent, +) +from torch.fx.experimental.sym_node import SymNode, SymTypes +from torch.types import py_sym_types +from torch.utils._ordered_set import OrderedSet +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils._sympy.functions import ( + Application, + CeilToInt, + CleanDiv, + FloorDiv, + FloorToInt, + IntTrueDiv, + IsNonOverlappingAndDenseIndicator, + Max, + Min, + Mod, + PythonMod, + TruncToInt, +) +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.printers import CppPrinter, PythonPrinter +from torch.utils._sympy.singleton_int import SingletonInt +from torch.utils._sympy.solve import try_solve +from torch.utils._sympy.symbol import make_symbol, symbol_is_type, SymT +from torch.utils._sympy.value_ranges import ( + bound_sympy, + SymPyValueRangeAnalysis, + ValueRangeError, + ValueRanges, +) +from torch.utils._traceback import CapturedTraceback, format_frame + + +if TYPE_CHECKING: + import types + + from torch import Tensor + from torch._dynamo.source import TensorPropertySource + from torch._subclasses.fake_tensor import FakeTensor + from torch.types import BoolLikeType, FloatLikeType, IntLikeType + + +InputList = list +DimList = list + +log = logging.getLogger(__name__) + + +class GuardOnDataDependentSymNode(RuntimeError): + cond: sympy.Basic + + def __init__(self, cond: sympy.Basic, *args: Any) -> None: + super().__init__(*args) + self.cond = cond + + +class PendingUnbackedSymbolNotFound(RuntimeError): + pass + + +aten = torch._ops.ops.aten # type: ignore[has-type] + +__all__ = [ + "guard_or_false", + "guard_or_true", + "has_symbolic_sizes_strides", + "create_contiguous", + "ShapeEnv", + "is_concrete_int", + "is_concrete_float", + "is_concrete_bool", + "has_static_value", + "guard_int", + "guard_float", + "guard_scalar", + "canonicalize_bool_expr", + "hint_int", + "SYMPY_INTERP", + "free_symbols", + "is_symbol_binding_fx_node", + "is_nested_int", + "SHAPEENV_EVENT_KEY", + "CURRENT_NODE_KEY", + "has_free_symbols", + "has_free_unbacked_symbols", + "sym_and", + "sym_eq", + "sym_or", + "SymbolicContext", + "StatelessSymbolicContext", + "StatefulSymbolicContext", + "SubclassSymbolicContext", + "SymIntSymbolicContext", + "TrackedFake", + "statically_known_true", + "statically_known_false", + "guard_size_oblivious", + "check_consistent", + "compute_unbacked_bindings", + "ConvertIntKey", + "rebind_unbacked", + "resolve_unbacked_bindings", + "is_accessor_node", + "ValueRangesSLoc", + "SymIntEqByExpr", + "Specialization", +] + +# FX node metadata keys for symbolic shape FX graph. +SHAPEENV_EVENT_KEY = "shapeenv_event" +CURRENT_NODE_KEY = "current_node" + + +def log_lru_cache_stats(wrapped_f: functools._lru_cache_wrapper[object]) -> None: + log.debug( + "lru_cache_stats %s: %s", + wrapped_f.__name__, # type: ignore[attr-defined] + wrapped_f.cumulative_cache_info(), # type: ignore[attr-defined] + ) + + +# Note about Sympy Expr/SympyBoolean/Basic typing: the Sympy hierarchy is +# +# Basic +# Expr +# SympyBoolean +# Relational +# +# Notably, Expr and SympyBoolean are not related. So use Basic when the +# expression could denote int, float OR bool, and otherwise use the more +# specific Expr for int/float and SympyBoolean for bool. +# +# In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime. +# So make sure only type checker evaluates this alias. +# Xref: https://www.internalfb.com/diff/D53324783 +SympyBoolean: TypeAlias = "sympy.logic.boolalg.Boolean" + + +_T = TypeVar("_T") +_SympyT = TypeVar("_SympyT", sympy.Expr, SympyBoolean, sympy.Basic) + + +class SymIntEqByExpr: + """ + This is a wrapper around SymInt which has alternative semantics for + equality. Specifically, instead of erroring or guarding, we + instead will hash/compare equality based on the underlying sympy + expression; e.g., s0 and s1 will always compare as False. + + NB: This does NOT do fancy analysis that maybe_evaluate_static does; + we can only reason through equalities that occur because to expressions + canonicalize to the same expression via regular simplification. + """ + + val: Union[torch.SymInt, int] + + def __init__(self, val: Union[torch.SymInt, int]) -> None: + self.val = val + + def __repr__(self) -> str: + return repr(self.val) + + def _extract(self) -> sympy.Expr: + if isinstance(self.val, torch.SymInt): + return self.val.node.expr + else: + return sympy.Integer(self.val) + + def __eq__(self, other: object) -> bool: + assert isinstance(other, SymIntEqByExpr) + + # int equality fastpath + if type(self.val) is int and type(other.val) is int: + return self.val == other.val + + return self._extract() == other._extract() + + def __hash__(self) -> int: + return hash(self._extract()) + + +def _nested_int_aware_sort( + tup: tuple[IntLikeType, int], +) -> tuple[int, IntLikeType, int]: + return ( + # Order nested ints by their coefficients. + # 1 here to order nested ints after non-nested-ints. + (1, tup[0].node.nested_int_coeff(), tup[1]) + if is_nested_int(tup[0]) + else (0, *tup) + ) + + +# Wrapper on lru_cache that reports statistics at process end +def lru_cache( + maxsize: Optional[int], +) -> Callable[[Callable[..., _T]], functools._lru_cache_wrapper[_T]]: + def inner(f: Callable[..., _T]) -> functools._lru_cache_wrapper[_T]: + wrapped_f = functools.lru_cache(maxsize)(f) + old_cache_clear = wrapped_f.cache_clear + prev_hits = 0 + prev_misses = 0 + + # TODO: There's a ref-cycle here (wrapped_f -> cumulative_cache_info + # -> wrapped_f) but cannot be solved with weakref as wrapped_f is not + # weakref'able on some versions of Python + + def cumulative_cache_info() -> functools._CacheInfo: + cur = wrapped_f.cache_info() + return functools._CacheInfo( + prev_hits + cur.hits, + prev_misses + cur.misses, + cur.maxsize, + cur.currsize, + ) + + def new_cache_clear() -> None: + nonlocal prev_hits, prev_misses + cur = wrapped_f.cache_info() + prev_hits += cur.hits + prev_misses += cur.misses + old_cache_clear() + + wrapped_f.cache_clear = new_cache_clear # type: ignore[attr-defined, method-assign] + wrapped_f.cumulative_cache_info = cumulative_cache_info # type: ignore[attr-defined, method-assign] + if log.isEnabledFor(logging.DEBUG): + atexit.register(log_lru_cache_stats, wrapped_f) # type: ignore[arg-type] + return wrapped_f + + return inner + + +# These are modules that contain generic code for interacting with ShapeEnv +# which are unlikely to identify a particular interesting guard statement +@lru_cache(None) +def uninteresting_files() -> set[str]: + import torch._compile + import torch._dynamo.eval_frame + import torch._inductor.sizevars + import torch._library.custom_ops + import torch._library.fake_impl + import torch._logging + import torch._subclasses.fake_tensor + import torch._subclasses.meta_utils + + mods = [ + sys.modules[__name__], + torch.fx.experimental.recording, + torch.fx.experimental.sym_node, + torch.fx.interpreter, + torch, + torch._compile, + torch._dynamo.eval_frame, + torch._inductor.sizevars, + torch._library.custom_ops, + torch._library.fake_impl, + torch._subclasses.meta_utils, + torch._subclasses.fake_tensor, + torch._logging._internal, + torch._logging.structured, + ] + import torch._dynamo.guards + + return ( + {inspect.getfile(m) for m in mods} + | torch._dynamo.guards.uninteresting_files() + | {""} + ) + + +class ConstraintViolationError(RuntimeError): + pass + + +def has_symbolic_sizes_strides(elem: torch.Tensor) -> bool: + return elem._has_symbolic_sizes_strides + + +Int: TypeAlias = Union[torch.SymInt, int] + + +def create_contiguous(shape: Sequence[Int]) -> list[Int]: + strides: list[Int] = [1] + for dim in reversed(shape[:-1]): + strides.append(dim * strides[-1]) # type: ignore[operator] + return list(reversed(strides)) + + +def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int: + """ + Retrieve the hint for an int (based on the underlying real values as observed + at runtime). If no hint is available (e.g., because data dependent shapes), + if fallback is not None, use that instead (otherwise raise an error). + """ + if isinstance(a, torch.SymInt): + return a.node.require_hint(fallback) + assert type(a) is int, a + return a + + +Scalar: TypeAlias = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool] + + +def has_hint(a: Scalar) -> bool: + if isinstance(a, SymTypes): + return a.node.has_hint() + return True + + +def is_concrete_int(a: IntLikeType) -> bool: + """ + Utility to check if underlying object + in SymInt is concrete value. Also returns + true if integer is passed in. + + Args: + a (SymInt or int): Object to test if it int + """ + assert isinstance(a, (SymInt, int)) + + if isinstance(a, int): + return True + + if isinstance(a.node.expr, sympy.core.numbers.Integer): + return True + + return False + + +def is_concrete_float(a: FloatLikeType) -> bool: + r"""Utility to check if underlying object + in SymInt is concrete value. Also returns + true if integer is passed in. + + Args: + a (SymInt or float): Object to test if it float + """ + assert isinstance(a, (SymFloat, float)) + + if isinstance(a, float): + return True + + if isinstance(a.node.expr, sympy.core.numbers.Float): + return True + + return False + + +def is_concrete_bool(a: BoolLikeType) -> bool: + """ + Utility to check if underlying object + in SymBool is concrete value. Also returns + true if integer is passed in. + + Args: + a (SymBool or bool): Object to test if it bool + """ + assert isinstance(a, (SymBool, bool)) + + if isinstance(a, bool): + return True + + if isinstance( + a.node.expr, (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse) + ): + return True + + return False + + +def has_static_value(a: Union[SymBool, SymFloat, SymInt, bool, float, int]) -> bool: + """ + User-code friendly utility to check if a value is static or dynamic. + Returns true if given a constant, or a symbolic expression with a fixed value. + + Args: + a (Union[SymBool, SymFloat, SymInt, bool, float, int]): Object to test + """ + assert isinstance(a, BoolLike + FloatLike + IntLike) + if ( + isinstance(a, BoolLike) + and is_concrete_bool(a) # type: ignore[arg-type] + or isinstance(a, FloatLike) + and is_concrete_float(a) # type: ignore[arg-type] + or isinstance(a, IntLike) + and is_concrete_int(a) # type: ignore[arg-type] + ): + return True + + assert isinstance(a, py_sym_types) + return a.node.shape_env.bound_sympy(a.node.expr).is_singleton() # type: ignore[union-attr] + + +def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool: + """ + Perform a guard on a symbolic boolean expression in a size oblivious way. + This is typically used when a non-oblivious test would result in a guard + on a data dependent value of which we don't know the value of at compile time. + When a guard is tested this way, we may diverge in behavior from how regular + PyTorch semantics would treat it. For more information, see + https://github.com/pytorch/pytorch/pull/118579 + """ + if isinstance(expr, torch.SymBool): + return expr.node.guard_size_oblivious("", 0) + else: + assert isinstance(expr, bool), expr + return expr + + +def check_consistent(new: _T, old: _T) -> None: + """ + Test that two "meta" values (typically either Tensor or SymInt) have + the same values, e.g., after retracing. If we don't understand the + quantities in question, we'll just skip the consistency check. + """ + # TODO: do boolean equality test too, see + # https://github.com/pytorch/pytorch/issues/124110 + scalar_types = (torch.SymInt, torch.SymFloat, int, float) + + if isinstance(new, torch.Tensor): + assert isinstance(old, torch.Tensor) + torch._check( + old.dim() == new.dim(), lambda: f"{old.shape} != {new.shape} (old != new)" + ) + # Do this manually so that each individual test is irrefutable + # (TODO: should be a helper for this, maybe sym_eq? That + # gives us a compound expression and I'm not sure it + # simplifies right now) + for i, j in zip(old.shape, new.shape): + torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)") + # NB: bool is subclass of int + elif isinstance(new, scalar_types) and not isinstance(new, bool): + assert isinstance(old, scalar_types) and not isinstance(old, bool), ( + f"{old} != {new}" + ) + torch._check(old == new, lambda: f"{old} != {new} (old != new)") + + +def resolve_unbacked_bindings( + shape_env: Optional[ShapeEnv], + bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]], +) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]: + """ + When we do fake tensor prop, we oftentimes will allocate new unbacked symints. + We then run proxy tensor mode, which populates node.meta["unbacked_bindings"] + with these new symints. To ensure consistency we use PropagateUnbackedSymInts + to rename unbacked bindings to their old ones. But all of the node metas are + still using the old bindings from before the renaming. This function helps to + post facto apply any renamings discovered in the PropogateUnbackedSymInts pass. + """ + if bindings is None: + return None + assert shape_env is not None + return {shape_env.unbacked_renamings.get(k, k): v for k, v in bindings.items()} + + +Result: TypeAlias = Union[torch.Tensor, tuple[torch.Tensor, ...]] + + +def rebind_unbacked( + shape_env: Optional[ShapeEnv], n: torch.fx.Node, result: Result +) -> None: + """ + Suppose we are retracing a pre-existing FX graph that previously had + fake tensor propagation (and therefore unbacked SymInts). When we retrace, + we re-propagate fake tensors, which results in new unbacked SymInts. + When this happens, we need to tell the shape environment about the equivalence + of the old and new unbacked SymInts. Pass us the old torch.fx.Node (which + has the old binding information) and the new result (which we can extract the + new unbacked SymInts out from). + """ + + # Inputs never need rebinding + if n.op == "placeholder": + return + + if bindings := resolve_unbacked_bindings( + shape_env, n.meta.get("unbacked_bindings") + ): + assert shape_env is not None + for raw_u0, path in bindings.items(): + u1 = pytree.key_get(result, path) + # Sometimes, things were previously unbacked bindings become constants. + # There are two situations this can happen. + # + # First, you might have a runtime assert that causes the + # constant-ification. In this case, the /binding/ itself will + # still be an unbacked symbol (because we will only force it + # to be a constant later in fake tensor propagation). In this + # case, u1 is a SymInt and we still do all our work as normal. + # + # But second, it might be that fake tensor propagation DIRECTLY + # converted the unbacked SymInt into a constant. This happens + # more rarely, but we have identified two situations it can + # validly occur: + # + # - If you have a tensor_version operator, these are initially + # allocated as unbacked SymInts, but after AOTAutograd they + # get forced specialized to specific values. In this case, + # there is no reason to do runtime asserts on them, this is + # just a hack to properly keep track of them to start. + # + # - If you have an item() call on a constant tensor, the result + # of the item() call is constant and we do not need runtime + # asserts on this symbol. In + # https://github.com/pytorch/pytorch/issues/140625 we have a + # case where in the initial trace of the program we are unable + # to determine that torch.tensor is constant, but then + # subsequent passes cause torch.tensor to become a constant and + # then the unbacked symbol goes poof. + # + # In all of these cases, it is no longer necessary to generate + # deferred runtime asserts, since other subsystems (e.g., the + # constant-ification pass) ensure that the quantity is now truly + # static and cannot change at runtime. So it's OK to discard + # in these situations. + # + # There is one more hazard (re + # https://github.com/pytorch/pytorch/issues/141248), the problem + # is that you can end up with "dangling" unbacked symbols that + # exist in the ShapeEnv but are never bound anywhere. You might + # like an invariant that unbacked symbols never get lost. But + # we do not have this invariant, so do not try to enforce it. + if isinstance(u1, int): + log.info( + "rebind_unbacked: discard %s %s %s -> %s", + n.target, + raw_u0, + path, + u1, + ) + continue + + # We only care about rebinding unbacked things + if u1.node.hint is not None: + continue + + raw_u1 = u1.node.expr + # Simplify SymBool binding + if ( + isinstance(raw_u1, sympy.Piecewise) + and len(raw_u1.args) == 2 + and ( + raw_u1_args0 := cast( + tuple[sympy.Basic, sympy.Basic], raw_u1.args[0] + ) + ) + and raw_u1_args0[0] == 1 + and isinstance(eq := raw_u1_args0[1], sympy.Eq) + and isinstance(new_raw_u1 := eq.lhs, sympy.Symbol) + and shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1)) + and eq.rhs == 1 + and cast(tuple[sympy.Basic, sympy.Basic], raw_u1.args[1]) == (0, True) + ): + # This is what the pattern match above is testing + repacked = _sympy_cast_symbool_to_symint_guardless( + sympy.Eq(new_raw_u1, 1) + ) + assert repacked == raw_u1, f"{repacked} != {raw_u1}" + # Cancel the to_int(to_bool(x)). This is sound because x in + # [0, 1] + raw_u1 = new_raw_u1 + + if not isinstance(raw_u1, sympy.Symbol): + assert not raw_u1.free_symbols, ( + f"should have been constant, but got {raw_u1}" + ) + continue + + # The old and new could be the same if you improperly hit the memo + # while retracing. Make sure you updated FakeTensorMode.epoch + assert raw_u0 != raw_u1, f"{raw_u0} possible memo disaster" + # Reuse the OLD symbol name + shape_env._rename_unbacked_to(raw_u1, raw_u0) + + +# NB: You could try to expand this to cover more cases by simply +# detecting whenever you have an int output, but this is a bit +# dangerous in case someone adds a function that returns an int but is +# mutating. So manually whitelist for now. +def is_accessor_node(node: torch.fx.Node) -> bool: + """ + Helper function to determine if a node is trying to access + a symbolic integer such as size, stride, offset or item. Currently + primarily only used in a DCE pass to figure out purity. + """ + + # Dynamo only exercised condition + if ( + node.op == "call_method" + and isinstance(node.args[0], torch.fx.Node) + and isinstance(node.args[0].meta.get("example_value"), torch.Tensor) + and node.target in ["size", "stride", "storage_offset", "item"] + ): + return True + + if node.op == "call_function" and node.target in [ + torch.ops.aten.sym_size, + torch.ops.aten.sym_size.default, + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_stride, + torch.ops.aten.sym_stride.default, + torch.ops.aten.sym_stride.int, + torch.ops.aten.sym_storage_offset, + torch.ops.aten.sym_storage_offset.default, + torch.ops.aten.sym_numel.default, + ]: + return True + + return False + + +def canonicalize_bool_expr(expr: _T) -> _T: + """ + Canonicalize a boolean expression by transforming it into a lt / le + inequality and moving all the non-constant terms to the rhs. + We canonicalize And / Ors / Not via cnf and then canonicalize their subexpr + recursively + nb. sympy.Rel.canonical is not good enough https://github.com/sympy/sympy/issues/25924 + + Args: + expr (sympy.Expr): Expression to canonicalize + """ + # Canonicalise an inequality by transforming it into a lt / le + # inequality and moving all the non-constant terms to the rhs + # We canonicalise And / Ors / Not via cnf + # nb. Relational.canonical in sympy is broken + # https://github.com/sympy/sympy/issues/25924 + + if not isinstance( + expr, (sympy.Rel, sympy.And, sympy.Or, sympy.Not, sympy.Eq, sympy.Ne) + ): + return expr + + if isinstance(expr, (sympy.And, sympy.Or, sympy.Not)): + expr = sympy.logic.boolalg.to_cnf(expr) + return _canonicalize_bool_expr_impl(expr) # type: ignore[arg-type, return-value] + + +def _sympy_from_args( + cls: type[Union[sympy.Add, sympy.Mul]], + args: list[sympy.Expr], + sort: bool = True, + is_commutative: Optional[bool] = None, +) -> sympy.Expr: + """ + Create a sympy expression from a list of arguments, optimizing for performance. + + This function creates a sympy Add or Mul expression from a list of arguments + while avoiding expensive operations like flattening. It handles sorting the + arguments appropriately based on the expression type. + + Args: + cls: The sympy class to create (Add or Mul) + args: List of sympy expressions to combine + sort: Whether to sort the arguments (default: True) + is_commutative: Whether the operation is commutative (default: None) + + Returns: + A sympy expression of type cls combining all arguments + + Raises: + ValueError: If cls is not sympy.Add or sympy.Mul + """ + + if not args: + return cls.identity # type: ignore[union-attr] + + # These args are already in canonical form, so we avoid calling + # Add(*args) to avoid expensive Add.flatten operation + if sort: + if cls is sympy.Add: + sort_fn = sympy.core.add._addsort + elif cls is sympy.Mul: + sort_fn = sympy.core.mul._mulsort + else: + raise ValueError(f"Unknown cls: {cls}") + + # we don't support non commutative with sort + assert is_commutative is True + if args[0].is_Number: + rest = args[1:] + sort_fn(rest) + return cls._from_args([args[0]] + rest, is_commutative=is_commutative) # type: ignore[attr-defined] + else: + args = args.copy() + sort_fn(args) + return cls._from_args(args, is_commutative=is_commutative) # type: ignore[attr-defined] + else: + # if the args are already sorted, we create directly + return cls._from_args(args, is_commutative=is_commutative) # type: ignore[attr-defined] + + +def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: + """ + After canonicalization, we are guaranteed to have eliminated Ge/Gt relations + (rewriting them to Le/Lt, respectively). + """ + if isinstance(expr, (sympy.And, sympy.Or)): + return type(expr)(*map(canonicalize_bool_expr, expr.args)) + + opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le} + t: Union[type[Any]] + if isinstance(expr, tuple(opposite.keys())): + rhs = expr.lhs - expr.rhs # type: ignore[attr-defined] + t = opposite[type(expr)] # type: ignore[index] + else: + assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne)) + rhs = expr.rhs - expr.lhs + t = type(expr) + + def is_neg(t: sympy.Expr) -> bool: + return (t.is_Number and t.is_negative) or ( + isinstance(t, sympy.Mul) and t.args[0].is_Number and t.args[0].is_negative + ) + + lhs = S.Zero + rhs = _reduce_to_lowest_terms(rhs) + if isinstance(rhs, sympy.Add): + pos = [] + neg = [] + for term in rhs.args: + if is_neg(term): + neg.append(-term) + else: + pos.append(term) + # these are already sorted + rhs = _sympy_from_args(sympy.Add, pos, sort=False, is_commutative=True) + # the terms were changed, so needs a sorting + lhs = _sympy_from_args(sympy.Add, neg, sort=True, is_commutative=True) + elif is_neg(rhs): + # lhs == 0 + lhs, rhs = -rhs, S.Zero + # We don't have to evaluate here because lhs, rhs came from a Boolean + # and it was already simplified + return t(lhs, rhs, evaluate=False) + + +def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr: + """ + Eliminates any integer factor from a given expression. + E.g., 6x + 4y reduces to 3x + 2y. + + Useful when an expression is == or != to 0. + """ + + def integer_coefficient(x: sympy.Expr) -> int: + if x.is_Integer: + return abs(int(x)) + elif x.is_Mul: + # If one of the args of a Mul is an Integer, it is the + # first arg. eg: args(2*x*3*y) == (6, x, y) + return abs(int(x.args[0])) if x.args[0].is_Integer else 1 # type: ignore[call-overload] + else: + return 1 + + def div_by_factor(x: sympy.Expr, factor: int) -> sympy.Expr: + if x.is_Integer: + return x / factor + elif x.is_Mul: + if x.args[0] != factor: + args = [x.args[0] / sympy.Integer(factor), *x.args[1:]] + else: + # Mul._from_args require a canonical list of args + # so we remove the first arg (x.args[0] / factor) if it was 1 + args = list(x.args[1:]) + return _sympy_from_args(sympy.Mul, args, is_commutative=x.is_commutative) + else: + raise AssertionError(f"illegal arg to div_by_factor: {x}") + + if expr.is_Add: + atoms = cast(Sequence[sympy.Expr], expr.args) + factor = functools.reduce(math.gcd, map(integer_coefficient, atoms)) + if factor == 1: + return expr + atoms = [div_by_factor(x, factor) for x in atoms] + return _sympy_from_args( + sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative + ) + elif expr.is_Integer: + return S.One + elif expr.is_Mul: + return div_by_factor(expr, integer_coefficient(expr)) + return expr + + +def is_nested_int(s: IntLikeType) -> TypeGuard[SymInt]: + return isinstance(s, torch.SymInt) and s.node.is_nested_int() + + +IterateExprsAtom: TypeAlias = Union[ + SymInt, SymFloat, SymBool, int, float, bool, sympy.Basic, torch.Tensor +] +IterateExprs: TypeAlias = Union[IterateExprsAtom, Sequence[IterateExprsAtom]] + + +def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]: + """ + Recursively iterate through a value and yield all sympy expressions contained within it. + + This function traverses various data structures (tensors, lists, tuples, etc.) and extracts + any symbolic expressions they contain. It's used for operations like finding free symbols + in complex nested structures. + + Args: + val: The value to extract sympy expressions from. Can be a symbolic type (SymInt, SymFloat, SymBool), + a sympy expression, a primitive type (int, float, bool), a container (tuple, list), + a sparse tensor, a regular tensor, None, or a torch.Generator. + + Yields: + sympy.Basic: Each sympy expression found in the value. + + Raises: + AssertionError: If the value is of an unsupported type. + """ + if isinstance(val, SymTypes): + # This allow applies to the jagged layout NestedTensor case as + # nested ints are not symbolic + if is_symbolic(val): + yield val.node.expr + elif isinstance(val, sympy.Basic): + yield val + elif isinstance(val, (int, float, bool)): + pass + elif isinstance(val, (tuple, list)): + for s in val: + yield from _iterate_exprs(s) + elif is_sparse_any(val): + yield from _iterate_exprs(val.size()) + elif isinstance(val, torch.Tensor): + yield from _iterate_exprs(val.size()) + yield from _iterate_exprs(val.stride()) + yield from _iterate_exprs(val.storage_offset()) + elif val is None: + pass + # see Note: [Generator arguments in AOTDispatcher] + elif isinstance(val, torch.Generator): + pass + else: + raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}") + + +def free_symbols(val: IterateExprs) -> OrderedSet[sympy.Symbol]: + """ + Recursively collect all free symbols from a value. + + This function traverses various data structures (tensors, lists, tuples, etc.) and extracts + all sympy symbols contained within them. It's useful for finding all symbolic variables + that a complex nested structure depends on. + + Args: + val: The value to extract symbols from. Can be a symbolic type (SymInt, SymFloat, SymBool), + a container (tuple, list), a tensor, or None. + + Returns: + OrderedSet[sympy.Symbol]: An ordered set of all free symbols found in the value. + """ + if val is None: + return OrderedSet() + + itr = _iterate_exprs(val) + + # we need at least 1 to call union, so we hand code the identity + try: + first_expr = next(itr) + except StopIteration: + return OrderedSet() + + # TODO: Apparently, returning an OrderedSet here breaks + # python test/distributed/tensor/test_dtensor_compile.py TestDTensorCompile.test_dtensor_dynamic + return first_expr.free_symbols.union(*(e.free_symbols for e in itr)) # type: ignore[return-value] + + +def has_free_symbols(val: IterateExprs) -> bool: + """Faster version of bool(free_symbols(val))""" + return not all((e.is_number or e.is_Boolean) for e in _iterate_exprs(val)) + + +def has_free_unbacked_symbols(x: IterateExprs) -> bool: + """Faster version of bool(free_unbacked_symbols(val))""" + from sympy.core.traversal import iterargs + + for s in _iterate_exprs(x): + for arg in iterargs(s): + if arg.is_Symbol and symbol_is_type( + arg, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT) + ): + return True + return False + + +def free_unbacked_symbols(x: IterateExprs) -> OrderedSet[sympy.Symbol]: + """Like free_symbols, but filtered to only report unbacked symbols""" + + # NB: keep synced with is_unbacked_symint + return OrderedSet( + s + for s in free_symbols(x) + if symbol_is_type(s, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT)) + ) + + +# WARNING: Don't use this on Dynamo produced graphs, they don't have meta +# setup! +def is_symbol_binding_fx_node(node: torch.fx.Node) -> Optional[sympy.Symbol]: + """ + Check if a given FX node is a symbol binding node. + + A symbol binding node is one that has a SymInt value in its meta that contains + a sympy Symbol expression, and is either a placeholder node or contains unbacked symbols. + + Args: + node (torch.fx.Node): The FX node to check + + Returns: + Optional[sympy.Symbol]: The sympy Symbol if the node is a symbol binding node, None otherwise + """ + if ( + "val" in node.meta + and isinstance(node.meta["val"], torch.SymInt) + and isinstance(node.meta["val"].node.expr, sympy.Symbol) + and ( + node.op == "placeholder" + or free_unbacked_symbols(node.meta["val"].node.expr) + ) + ): + return node.meta["val"].node.expr + return None + + +def find_symbol_binding_fx_nodes( + graph: torch.fx.Graph, +) -> dict[sympy.Symbol, torch.fx.Node]: + """ + Find all nodes in an FX graph that bind sympy Symbols. + + This function scans through all nodes in the given FX graph and identifies + nodes that bind sympy Symbols (typically placeholder nodes with SymInt values). + When multiple nodes bind the same symbol, only the first occurrence is kept. + + Args: + graph: The FX graph to search for symbol binding nodes + + Returns: + A dictionary mapping from sympy Symbols to their binding FX nodes + """ + r = {} + # NB: Prefer first occurrence of symbol + for node in graph.nodes: + if (s := is_symbol_binding_fx_node(node)) is not None and s not in r: + r[s] = node + return r + + +@dataclass(frozen=True) +class Specialization: + """ + This class is used in multi-graph compilation contexts where we generate + multiple specialized graphs and dispatch to the appropriate one at runtime. + This allows us to optimize the trade-off between performance and generality + by creating specialized versions for common patterns (e.g., x.shape[0] % 16 == 0) + while maintaining a general fallback. + """ + + source: TensorPropertySource + check_fn: Callable + + +# Analogous to ConvertIntSource +@dataclass(frozen=True) +class ConvertIntKey: + def __str__(self) -> str: + return ".cast_symbool_to_symint_guardless()" + + def get(self, b: bool) -> IntLikeType: + """Get the int value from bool""" + return cast_symbool_to_symint_guardless(b) + + +@dataclass(frozen=True) +class CallMethodKey: + name: str + + def __str__(self) -> str: + return f".{self.name}()" + + def get(self, o: Any) -> Any: + """Call the method on object""" + return getattr(o, self.name)() + + +@dataclass(frozen=True) +class InnerTensorKey: + inner_name: str + + def __str__(self) -> str: + return f".{self.inner_name}" + + def get(self, o: Any) -> Any: + """Get the inner tensor attribute""" + return getattr(o, self.inner_name) + + +@dataclass(frozen=True) +class DivideByKey: + divisor: IntLikeType + + def __str__(self) -> str: + return f".__floordiv__({self.divisor})" + + def get(self, o: int) -> int: + """Divide object by divisor""" + return o // self.divisor + + +def _free_unbacked_symbols_with_path( + a: object, + path: pytree.KeyPath, + real: Optional[object] = None, + shape_env: Optional[ShapeEnv] = None, + pending: Optional[set[sympy.Symbol]] = None, + simplify: bool = False, +) -> dict[sympy.Symbol, pytree.KeyPath]: + """ + Recursively traverses a structure to find unbacked symbols and their access paths. + + This function walks through tensors, lists, tuples, and symbolic values to locate + unbacked symbols that are in the pending set, and returns a mapping from those + symbols to their access paths in the structure. + + Args: + a: The object to traverse (tensor, list, tuple, SymInt, etc.) + path: The current path in the object tree + real: Optional real tensor corresponding to the fake tensor being traversed + shape_env: Optional ShapeEnv to register unbacked values with + pending: Set of unbacked symbols to look for (will be modified in-place) + simplify: Whether to use simplified expressions + + Returns: + A dictionary mapping unbacked symbols to their access paths + """ + go = functools.partial( + _free_unbacked_symbols_with_path, + shape_env=shape_env, + pending=pending, + simplify=simplify, + ) + + def expr(s: Union[SymInt, SymFloat, SymBool]) -> sympy.Expr: + if simplify: + return s.node.expr + # (When called from compute_unbacked_bindings) + # NB: Intentionally access _expr, not expr, do not want + # simplification! + return s.node._expr + + if pending is None: + pending = set() + r = {} + if isinstance(a, (tuple, list)): + # NB: real is apparently not always a tuple/list here + # python test/inductor/test_torchinductor.py CpuTests.test_index_propagation_nested_indirect_indexing_cpu + for i in range(len(a)): + r.update( + go( + a[i], + path + (pytree.SequenceKey(i),), + real=real[i] if real is not None else None, # type: ignore[index] + ) + ) + elif is_traceable_wrapper_subclass(a): + # TODO: Determine if this is correct + attrs, _ = a.__tensor_flatten__() + for attr in attrs: + sub = getattr(a, attr) + r.update(go(sub, path + (InnerTensorKey(attr),))) + elif isinstance(a, torch.Tensor): + from torch._subclasses.fake_tensor import FakeTensor + + assert isinstance(a, FakeTensor) + r.update( + go( + a.size(), + path + (CallMethodKey("size"),), + real=a.real_tensor.size() if a.real_tensor is not None else None, + ) + ) + if a.layout not in [ + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + ]: + r.update( + go( + a.stride(), + path + (CallMethodKey("stride"),), + real=a.real_tensor.stride() if a.real_tensor is not None else None, + ) + ) + r.update( + go( + a.storage_offset(), + path + (CallMethodKey("storage_offset"),), + real=( + a.real_tensor.storage_offset() + if a.real_tensor is not None + else None + ), + ) + ) + + elif ( + isinstance(a, (torch.SymInt, torch.SymFloat)) + and isinstance(s := expr(a), sympy.Symbol) + and s in pending + ): + r[s] = path + if shape_env and real is not None: + assert isinstance(real, (int, float)) + shape_env.set_unbacked_var_to_val(s, real) + pending.remove(s) + # When an unbacked SymInt is perfectly divisible by an integer + # constant, we replace it with the integer constant to improve + # reasoning capabilities. However, in synthetic examples, it is + # then possible that the factor never is explicitly allocated. + # Fortunately, we can compute it by division. + elif ( + isinstance(a, torch.SymInt) + and isinstance(s := expr(a), sympy.Mul) + and len(s.args) == 2 + and isinstance(lhs := s.args[0], (sympy.Integer, sympy.Symbol)) + and isinstance(rhs := s.args[1], sympy.Symbol) + # support exactly one unbacked for now + and ((rhs in pending) ^ (lhs in pending)) + # support constant coefficient or backed symbolic coefficient + and ( + isinstance(coeff := lhs if lhs not in pending else rhs, sympy.Integer) + or shape_env + and coeff in shape_env.var_to_val + ) + ): + + def _symint_wrap(s: sympy.Symbol) -> SymInt: + return shape_env.create_symintnode( # type: ignore[union-attr] + s, + hint=int(shape_env.var_to_val[s]), # type: ignore[union-attr] + source=shape_env.var_to_sources.get(s, [None])[0], # type: ignore[union-attr] + ) + + unbacked = lhs if lhs in pending else rhs + divisor: IntLikeType = ( + int(coeff) + if shape_env and isinstance(coeff, sympy.Integer) + else _symint_wrap(coeff) + ) + # TODO: DivideByKey needs to test divisibility at runtime! + r[unbacked] = path + (DivideByKey(divisor),) + if real is not None: + assert isinstance(real, int) + val = ( + real // int(coeff) + if isinstance(coeff, sympy.Integer) + else CleanDiv(real, coeff) + ) + if shape_env: + shape_env.set_unbacked_var_to_val(unbacked, val) + pending.remove(unbacked) + # The annoyance here arises from the fact that SymBool is + # allocated by allocating a SymInt and then testing if it's equal + # to one. So you have a complicated binding site logic for this. + elif ( + isinstance(a, torch.SymBool) + and isinstance(s := expr(a), sympy.Eq) + # This must match create_unbacked_symbool EXACTLY + and isinstance(s.lhs, sympy.Symbol) + and s.rhs == 1 + and s.lhs in pending + ): + r[s.lhs] = path + (ConvertIntKey(),) + if real is not None: + assert type(real) is bool + if shape_env: + shape_env.set_unbacked_var_to_val(s, int(real)) + pending.remove(s.lhs) + + return r + + +def compute_unbacked_bindings( + shape_env: Optional[ShapeEnv], + example_value: object, + old_example_value: Optional[object] = None, + peek: bool = False, +) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]: + """ + After having run fake tensor propagation and producing example_value + result, traverse example_value looking for freshly bound unbacked + symbols and record their paths for later. It is an error if + we have allocated an unbacked SymInt but it cannot be found in + example_value. (NB: this means if you have a multi-output + function, you must call this on the tuple of tensor output, you + cannot wait!) + + The peek parameter lets you check out what the bindings are without + changing the affected list. This is primarily useful for ensuring + unbacked_var_to_val is promptly populated when propagate_real_tensors is on. + """ + if shape_env is None: + return None + + fs = shape_env.pending_fresh_unbacked_symbols + pending = set(fs) + if not pending: + return None + + if not peek: + log.info("compute_unbacked_bindings %s", fs) + fs.clear() + + symbol_to_path = _free_unbacked_symbols_with_path( + example_value, (), shape_env=shape_env, pending=pending, simplify=False + ) + if not peek and pending: + extra = ( + repr((example_value.stride(), example_value.storage_offset())) + if isinstance(example_value, torch.Tensor) + else "" + ) + raise PendingUnbackedSymbolNotFound( + f"Pending unbacked symbols {pending} not in returned outputs {example_value} {extra}.\n" + "Did you accidentally call new_dynamic_size() or item() more times " + "than you needed to in your fake implementation?\n" + "For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit" + ) + + # Why do we have to do some rebinding here? If the original FX node + # wasn't a binding site because you had a memo hit, but post + # translation you aren't a memo hit anymore, there's now a new binding + # site... but we know (because it's the same FX node) that the value + # is actually the same, they're just not obviously equal anymore. + # + # The logic here is written carefully, because unlike the + # bind_unbacked case, we are not guaranteed to have a symbol for + # old_sym. If we have a symbol, do regular rename unbacked to; but if + # we don't, we need to specially eliminate the fresh unbacked symbol + # (NB: we are /trusting/ that the memoization is correct, and that we + # don't need to generate a new runtime assert. This is load bearing, + # as repropagation can happen after we've frozen runtime asserts.) + if old_example_value is not None: + for keypath in symbol_to_path.values(): + old_sym = pytree.key_get(old_example_value, keypath) + new_sym = pytree.key_get(example_value, keypath) + if isinstance(new_sym, SymTypes) and isinstance( + new_s := new_sym.node.expr, sympy.Symbol + ): + if ( + isinstance(old_sym, SymTypes) + and (old_s := old_sym.node.expr) != new_s + ): + if isinstance(old_s, sympy.Symbol): + shape_env._rename_unbacked_to(new_s, old_s) + else: + shape_env._eliminate_unbacked(new_s, old_s) + elif not isinstance(old_sym, SymTypes): + shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym)) + + return symbol_to_path + + +# Note [guard_or_] +# The following two functions are common utilities used while defining unbacked semantics +# of various framework code. Those would be used in situations you prefer to guard and know +# the result of the expression over not guarding, but in case you hit a data dependent error +# you are ok with just returning true or false. +# +# When to use this? +# (1) If you can use a higher level combinator prefer using those instead, they are definitely safe (modulo short-circuiting). +# +# (2) It can be used if the program would behave equivalently if _guard_or returned true or false. +# Many inductor optimizations fall in this bracket for example. +# +# (3) Finally, it's even be OK if the program wouldn't behave equivalently, so long as the +# change is semantics preserving. It can be semantics preserving if the program errors in more +# cases than it did previously (but otherwise behaves identically), or if it changes some quantity +# in a way that doesn't matter (e.g., strides often fall in this bucket.) +# +# (4) Specialize for the general case and add a runtime assertion that would fail during +# runtime if the conditions for the general case are not satisfied. Examples for this are; +# assuming expand/reshape inputs are not -1. or assuming the non-broadcasting path. +# +def _guard_or(a: BoolLikeType, default: bool) -> bool: + """ + Try to guard a, if data dependent error encountered just return default. + """ + if not isinstance(a, SymBool): + assert isinstance(a, bool) + return a + + # if backed_size_oblivious is True we treat backed as unbacked here. + if torch.fx.experimental._config.backed_size_oblivious: + result = _static_eval_sym_bool(a) + return result if result is not None else default + + shape_env = getattr(a.node, "shape_env", None) + + # xla symnode path. + if shape_env is None: + return guard_bool(a) + + sym_node = a.node + r = sym_node.shape_env.evaluate_sym_node( + sym_node, size_oblivious=False, fallback_value=default + ) + return bool(r) + + +def guard_or_false(a: BoolLikeType) -> bool: + """ + Try to guard a, if data dependent error encountered just return false. + """ + return _guard_or(a, False) + + +def guard_or_true(a: BoolLikeType) -> bool: + """ + Try to guard a, if data dependent error encountered just return true. + """ + return _guard_or(a, True) + + +def _static_eval_sym_bool(x: SymBool) -> Optional[bool]: + assert isinstance(x, SymBool) + expr = x.node.expr + + try: + # Shape env access is inside the try on purpose. xla symnode does not + # have it on its attributes. + shape_env = x.node.shape_env + simplified = shape_env._maybe_evaluate_static(expr) + if simplified is not None: + return bool(simplified) + else: + return None + except Exception: + log.debug("Could not simplify %s", expr) + return None + + +def statically_known_false(x: BoolLikeType) -> bool: + """ + Returns True if x can be simplified to a constant and is False. + If x cannot be evaluated from static, we return False + + .. note:: + This function doesn't introduce new guards, so the expression may end + up evaluating to False at runtime even if this function returns False. + + Args: + x (bool, SymBool): The expression to try statically evaluating + """ + if not isinstance(x, SymBool): + assert isinstance(x, bool) + return not x + + result = _static_eval_sym_bool(x) + if result is None: + return False + + return not result + + +def statically_known_true(x: BoolLikeType) -> bool: + """ + Returns True if x can be simplified to a constant and is true. + + .. note:: + This function doesn't introduce new guards, so the expression may end + up evaluating to true at runtime even if this function returns False. + + Args: + x (bool, SymBool): The expression to try statically evaluating + """ + if not isinstance(x, SymBool): + assert isinstance(x, bool) + return x + + result = _static_eval_sym_bool(x) + if result is None: + return False + + return result + + +def sym_and(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType: + """ + and, but for symbolic expressions, without bool casting. + """ + if len(others) == 0: + return x + for y in others: + x = operator.and_(x, y) + return x + + +def sym_eq(x: _T, y: _T) -> BoolLikeType: + """ + Like ==, but when run on list/tuple, it will recursively test equality + and use sym_and to join the results together, without guarding. + """ + if isinstance(x, (tuple, list)) and isinstance(y, (list, tuple)): + if len(x) != len(y): + return False + return functools.reduce(operator.and_, map(sym_eq, x, y), True) + elif isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)): + return x == y + else: + raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}") + + +def sym_or(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType: + """ + or, but for symbolic expressions, without bool casting. + """ + if len(others) == 0: + return x + for y in others: + x = operator.or_(x, y) + return x + + +def guard_scalar( + a: Union[SymBool, SymInt, SymFloat, int, bool, float], +) -> Union[bool, int, float]: + """ + Guard a scalar value, which can be a symbolic or concrete boolean, integer, or float. + + This function dispatches to the appropriate guard function based on the type of the input. + + Args: + a: A symbolic or concrete scalar value (bool, int, or float) + + Returns: + The concrete value after guarding + + Raises: + AssertionError: If the input is not a recognized scalar type + """ + if isinstance(a, (SymBool, bool)): + return guard_bool(a) + elif isinstance(a, (SymInt, int)): + return guard_int(a) + elif isinstance(a, (SymFloat, float)): + return guard_float(a) + else: + raise AssertionError(f"unrecognized scalar {a}") + + +def _advise_is_size(a: SymInt) -> None: + """ + Don't use this directly; use torch._check_is_size instead. + + This is a softer version of _constrain_range_for_size (with min=0, + max=Inf). Instead of forcibly constraining a variable (and erroring if we + failed to constrain it), it will simply advise us that a size is + constrained in some way. We will always defer a runtime assert for this + constraint if we cannot prove it at compile-time, but we we only + *sometimes* learn useful extra information at compile-time with this + information. This is in contrast to constrain_range_for_size, where if + you don't call that on a fresh unbacked symint, chances are we will choke. + + TODO: Make Dynamo handle this appropriately if this is seen in Dynamo-ed + code. Right now this is only really used in code with AOTAutograd trace + through, so it is not a big problem that this isn't supported, but in + principle all of this code should be Dynamo'able too. + + TODO: I didn't support min/max because I didn't have a use case where this + actually helped. In principle we can support it, it just makes the + implementation below more complicated. + """ + + # This must always succeed, because the sole allowed caller _check_is_size + # was responsible for expect_true'ing this + # This assert triggers expensive sym compute, do not do it until its cheap. + # assert a >= 0 + + # NB: it's important not to constrain range for size for *hinted* SymInts, + # because it is not only unsound, it will immediately trip our asserts + # that hints have to be consistent with static analysis! If you somehow + # have an unbounded SymInt that later constrains to 1, this will be + # inconsistent with the range + if ( + isinstance(a, SymInt) + and isinstance(a.node, SymNode) + and isinstance(a.node.expr, sympy.Symbol) + and a.node.shape_env.is_unbacked_symint(a.node.expr) + ): + _constrain_range_for_size(a) + + +def _advise_is_bounded(a: SymInt, upper_bound: IntLikeType) -> None: + if ( + isinstance(a, SymInt) + and isinstance(a.node, SymNode) + and isinstance(a.node.expr, sympy.Symbol) + and a.node.shape_env.is_unbacked_symint(a.node.expr) + and isinstance(upper_bound, int) # TODO: relax + ): + a.node.shape_env._constrain_is_bounded(a.node.expr, upper_bound) + + +def _constrain_range_for_size( + a: SymInt, min: Optional[int] = None, max: Optional[int] = None +) -> None: + """ + This function is NOT INTENDED to be used by itself. + """ + + if isinstance(a, (SymFloat, SymBool)): + raise ValueError("Constraining SymFloat/SymBool is nyi") + + assert isinstance(a, SymInt), "can only constrain range for SymInt" + assert isinstance(a.node.expr, sympy.Symbol), f"constraining non-Symbols NYI: {a}" + + a.node.shape_env._constrain_range_for_size(a.node.expr, min, max) + + +# inclusive both ways +def constrain_range( + a: SymInt, *, min: Optional[int], max: Optional[int] = None +) -> None: + """ + Applies a constraint that the passed in SymInt must lie between min-max + inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning + that it can be used on unbacked SymInts). If min/max are None, we assume + that the dimension is unbounded in that direction. Repeated application + of constrain_range intersects the ranges. This is a fairly low level API + that doesn't have a lot of safety guarantees (TODO: provide higher level + APIs). + + Currently, we use this API in the following circumstance: when we allocate + an unbacked SymInt, denoting an integer quantity which is data dependent, + we ordinarily do not know anything about what values it may take. This + means that any sort of guard on it will immediately fail. However, in + many cases, we know something about the unbacked SymInt: for example, we + know that nonzero(x).size(0) must be >= 0. We use constrain_range to + narrow the possible range, declaring that negative symbols are impossible. + This permits to definitely answer True to queries like 'nnz >= 0', even if + we don't know what the actual (hinted) value of 'nnz' is. In fact, we + actually use constrain_range to unsoundly discharge common guards: for an + unbacked SymInt produced by nonzero, we will also assume that it is not + equal to 0/1 (even though these are perfectly possible values at runtime), + because we generally expect graphs that are valid for N=2 to also be valid + for N=1. + """ + if min is None: + min = -int_oo + if max is None: + max = int_oo + + if max < min: + raise ValueError( + "Maximum value to constrain_as_size can't be less than the specified min value, " + "received min={min} and max={max}" + ) + + if isinstance(a, int): + if not (min <= a <= max): + raise ValueError(f"Invalid value {a} for range [{min}:{max}]") + return + + a.node.shape_env._constrain_range(a.node.expr, min, max) + + +def constrain_unify(a: torch.SymInt, b: torch.SymInt) -> None: + """ + Given two SymInts, constrain them so that they must be equal. NB: + this will not work with SymInts that represent nontrivial expressions + (yet!) + """ + if not isinstance(a, SymInt): + if not isinstance(b, SymInt): + assert a == b + return + else: + shape_env = b.node.shape_env + else: + shape_env = a.node.shape_env + + shape_env._constrain_unify(a, b) + + +# Assume that a boolean is true for the purposes of subsequent symbolic +# reasoning. This will keep track of corresponding runtime checks to verify +# that the result is upheld: either as a regular guard, or as a special set +# of asserts which are triggered when an unbacked SymInt is allocated. +# +# DO NOT use this function for these cases: +# +# - This is inappropriate for "branching" conditions (where both +# true and false result in valid programs). We will always assume +# the condition evaluates true, and so it will never be possible +# to trace the false condition when you use it. For true branching +# on unbacked SymInts, you must use torch.cond; if you incorrectly +# use expect_true in this case, you will make the false branch +# unreachable (as we will simply assume that only the true branch +# is ever exercised). +# +# - This is inappropriate for situations where you know some other system +# invariant guarantees that this property holds, since you don't +# really need to insert a runtime check in that case. Use something +# like constrain_range in that case. +# +# This API has a hitch. To avoid having to reimplement error reporting +# capabilities, this function CAN return False. The invariant is that +# the surrounding code must raise an error when this function returns +# False. This is quite low level, so we recommend using other functions +# like check() which enforce this in a more intuitive way. +# +# By the way, this name is a nod to the __builtin_expect macro, +# which is used similarly (but unlike __builtin_expect, you MUST fail +# in the unlikely branch.) (I think expect is a good name; in recent +# versions of C++, this is replaced with [[likely]], which is weaker +# and not accurate for this function!) +def expect_true(a: BoolLikeType, skip: int = 0) -> bool: + if isinstance(a, SymBool): + # TODO: check perf implications of this + frame = inspect.currentframe() + for _ in range(skip + 1): # always run this loop at least once + if frame is None: + break + frame = frame.f_back + return a.node.expect_true( + frame.f_code.co_filename if frame else "", frame.f_lineno if frame else 0 + ) + assert type(a) is bool, a + return a + + +def guard_bool(a: BoolLikeType) -> bool: + if isinstance(a, SymBool): + return a.node.guard_bool("", 0) # NB: uses Python backtrace + assert type(a) is bool, a + return a + + +def guard_int(a: IntLikeType) -> int: + if isinstance(a, SymInt): + return a.node.guard_int("", 0) # NB: uses Python backtrace + assert type(a) is int, a + return a + + +def guard_float(a: FloatLikeType) -> float: + if isinstance(a, SymFloat): + return a.node.guard_float("", 0) # NB: uses Python backtrace + assert isinstance(a, float), a + return a + + +# Given a GraphModule, return all the FakeTensors for all the placeholders +def fx_placeholder_vals(gm: torch.fx.GraphModule) -> list[object]: + return [n.meta["val"] for n in gm.graph.nodes if n.op == "placeholder"] + + +def fx_placeholder_targets(gm: torch.fx.GraphModule) -> list[str]: + return [n.target for n in gm.graph.nodes if n.op == "placeholder"] + + +# Given a GraphModule and arguments to run it with, evaluate that the guards +# for its associated ShapeEnv are satisfied by the passed arguments. This +# WILL check for duck sizing. +def eval_guards( + gm: torch.fx.GraphModule, *args: Tensor, ignore_static: bool = True +) -> bool: + return gm.shape_env.evaluate_guards_for_args( # type: ignore[operator, union-attr] + fx_placeholder_vals(gm), args, ignore_static=ignore_static + ) + + +def bind_symbols(gm: torch.fx.GraphModule, *args: Tensor) -> dict[sympy.Symbol, int]: + return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) # type: ignore[operator, union-attr] + + +class DimDynamic(Enum): + """ + Controls how to perform symbol allocation for a dimension. It is always + sound to default this to DYNAMIC, but the policies DUCK and STATIC can + result in better trace-time and compile-time performance, as they reduce + the number of allocated symbols and generally make your graph more static. + + NB: If we notice you've applied a constraint to the dimension, we will + force it to DYNAMIC for simplicity. + + DimDynamic is controlled by a variety of higher level UX features. + Currently: + + - In eager mode, the default policy is DUCK. + - The default is changed to STATIC with assume_static_by_default. + - An individual dim is marked DYNAMIC if you mark_dynamic_dim. + - In export mode, the default policy is STATIC. + - An individual dim is marked DYNAMIC if you specify it in + dynamic_shapes passed to export. + """ + + # Treat the dimension symbolically + DYNAMIC = 0 + # Treat the dimension symbolically, but if its hint matches another + # dynamic dimension, unify the two symbols ("duck sizing") + DUCK = 1 + # Treat the dimension statically based on its hint + STATIC = 2 + # Treat the dimension as a size-like unbacked + SIZE_LIKE_UNBACKED = 3 + # Infer the strides from stride. If size is static, strides will be static as well. + INFER_STRIDE = 4 + # Like SIZE_LIKE_UNBACKED, but there's a hint + OBLIVIOUS_SIZE = 5 + + +# NB: These constraints affect both clients and backends: given some +# constraint C, the client must pass inputs that satisfy the constraint, +# while a backend must not introduce guards BEYOND this constraint. +# For clarity, we document the implications on both sides for both the client +# and the backend. +# +# NB: These constraints are on a *single* dimension. In principle, we could +# also have multi-dimension constraints, but our guess is that this is not +# actually useful and so we are not supporting it right now. +# +# NB: Strict constraints are typically only suitable for export, as in eager +# a backend like inductor may validly introduce extra, discretionary guards +# to improve performance of code. A StrictMinMaxConstraint would be brittle +# under future optimizations performed by inductor; we don't guarantee +# eager code with StrictMinMaxConstraint will keep working in the future! + + +@dataclass(frozen=True) +class Constraint: + warn_only: bool + + +@dataclass(frozen=True) +class StrictMinMaxConstraint(Constraint): + """ + For clients: the size at this dimension must be within 'vr' (which + specifies a lower and upper bound, inclusive-inclusive) AND it + must be non-negative and should not be 0 or 1 (but see NB below). + + For backends: there must not be any guards on this dimension which + are not implied by the given lower and upper bound. Regardless of + the lower bound, the backend can assume the size is non-negative + and that it is not 0 or 1. + + An unbounded StrictMinMaxConstraint can be thought of as a strict version + of "RelaxedUnspecConstraint". + + NB: Export will often unsoundly assume that a graph works for 0/1, even + though at trace time we assumed size is not 0 or 1. The idea is that + if we produce a graph that works for a range of values, it will be OK + for N=0/1 too. + """ + + vr: ValueRanges + + def render(self, source: Source) -> str: + """Format the constrain equation""" + # TODO: better printing for -oo and oo + return f"{self.vr.lower} <= {source.name()} <= {self.vr.upper}" + + +@dataclass(frozen=True) +class RelaxedUnspecConstraint(Constraint): + """ + For clients: no explicit constraint; constraint is whatever is implicitly + inferred by guards from tracing. + + For backends: there must exist at least TWO possible values for the + size at this dimension which satisfy the guards for this dimension. + + In other words, this constraint helps us distinguish between "we don't + care if this dimension specializes or not" versus "this dimension must be + unspecialized." However, this constraint doesn't say very much about what + specialization is permitted; for example, if we guard on a size being + even, this would still be acceptable under an unspec constraint. This + makes RelaxedUnspecConstraint useful for eager mode, where your backend compiler + may add constraints to otherwise dynamic dimensions; we can't assert that + there are NO guards as this is brittle because compilers should be able to + add extra constraints. If you want to assert that there are no guards, + use StrictMinMaxConstraint with an unbounded ValueRanges. + """ + + def render(self, source: Source) -> str: + return f"RelaxedUnspecConstraint({source.name()})" + + +# NB: None here indicates the client constraint is whatever is implicitly +# inferred by guards from tracing, and that a backend can add whatever guards +# it wants (including fully specializing the value). +DimConstraint = Union[StrictMinMaxConstraint, RelaxedUnspecConstraint, None] + + +@dataclass(frozen=True) +class EqualityConstraint(Constraint): + """ + Represent and decide various kinds of equality constraints between input sources. + + A "source pair" is a pair of input sources for dynamic dimensions that + are specified equal. We represent `source_pairs` in a union-find forest + so that we can efficiently check whether two such sources are transitively equal. + + A "derived equality" relates an input source to an expression over a root. + The root can be another input source, corresponding to some dynamic dimension, + or a phantom symbol that does not directly represent any dynamic dimension. We + represent `derived_equalities` involving input sources in a transitively-closed map + so that we can efficiently check whether an input source is transitively equal to + a given expression over another input source. + (NOTE: In contrast, it is easy to decide whether an input source is transitively equal + to a given expression over a phantom symbol; such expressions are already in canonical + form and so the problem reduces to symbolic expression equality.) + """ + + source_pairs: list[tuple[Source, Source]] + derived_equalities: list[ + tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]] + ] + phantom_symbols: list[sympy.Symbol] + relaxed_sources: set[Source] + + _parents: dict[Source, Source] = field(init=False) + _defs: dict[Source, sympy.Expr] = field(init=False) + + def __post_init__(self) -> None: + """ + Pre-processing to answer queries `is_equal` and `is_derived` below. + + Example: Suppose we are given: + source_pairs [a = b, b = c] + derived_equalities [d = c + 1, e = d - 1] + We first construct a union find with source_pairs: + _parents = {a: a, b: a, c: a} + Then we compute canonical symbolic expressions, recursively applying derived_equalities + until we bottom out: + _defs = {d: c + 1, e: (c + 1) - 1 aka c} + """ + + # self._parents is a map from input sources to input sources where, conceptually, + # these are directed edges in a union-find forest + _parents: dict[Source, Source] = {} + object.__setattr__(self, "_parents", _parents) + # self._defs is a map from input sources to "canonical" symbolic expressions, + # i.e., unary expressions with symbols that corresponds to regular Dims (i.e., + # not derived Dims) + _defs: dict[Source, sympy.Expr] = {} + object.__setattr__(self, "_defs", _defs) + + for source1, source2 in self.source_pairs: + # preprocess into a union-find forest + self._union(self._find(source1), self._find(source2)) + for source, root, fn in self.derived_equalities: + # preprocess into a transitively-closed map + # NOTE(avik): we reuse the union-find forest for canonicalizing input sources + if isinstance(root, sympy.Symbol): + self._defs[self._find(source)] = fn(root) + else: + self._defs[self._find(source)] = fn(self._rewrite(root)) + + def _find(self, source: Source) -> Source: + # chase edges to find the root of this equivalence class + if source in self._parents: + return self._find(self._parents[source]) + else: + return source + + def _union(self, root1: Source, root2: Source) -> None: + # merge two equivalence classes by adding an edge from one root to the other + if root1 != root2: + self._parents[root1] = root2 + + def _rewrite(self, src: Source) -> sympy.Expr: + # always represent the given source by the root of its equivalence class + src = self._find(src) + if src in self._defs: + # simply look up the definition if it exists + # NOTE(avik): This works because definitions are always transitively-closed; + # otherwise we would have to do recursive rewriting. + return self._defs[src] + else: + # otherwise, create a symbol representing the source + return sympy.Symbol(src.name()) + + def is_equal(self, source1: Source, source2: Source) -> bool: + return ( + # check whether source1 and source2 have the same root + # or are relaxed + (src1 := self._find(source1)) in self.relaxed_sources + or (src2 := self._find(source2)) in self.relaxed_sources + or src1 == src2 + # check whether source1 is derived equal to source2 + or self.is_derived(source1, source2, lambda x: x) + ) + + def is_derived( + self, src: Source, symbol_src: Source, fn: Callable[[sympy.Expr], sympy.Expr] + ) -> bool: + # check whether both src and symbol_src have the same definition + return self._rewrite(src) == fn(self._rewrite(symbol_src)) + + +def _assert_symbol_context(symbolic_context: object) -> TypeGuard[SymbolicContext]: + assert isinstance(symbolic_context, SymbolicContext), ( + "Invalid symbolic_context object" + ) + assert type(symbolic_context) is not SymbolicContext, ( + "Illegal usage of symbolic_context ABC" + ) + return True + + +def _is_supported_equivalence(expr: sympy.Expr) -> bool: + # Currently supported Dim ops are linear expressions with integer coefficients. + # So check that expr only contains +, *, ints, and a single occurrence of a symbol. + # (See also documentation of dynamic_shapes._DerivedDim.) + if isinstance(expr, (sympy.Add, sympy.Mul)): + if len(expr.args) > 2: + return False + lhs, rhs = expr.args + return (_is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or ( + isinstance(lhs, sympy.Integer) and _is_supported_equivalence(rhs) + ) + return isinstance(expr, sympy.Symbol) + + +def _has_uninterpretable_sympy_function(expr: sympy.Basic) -> bool: + """ + Add functions that our sympy interpreter can't reify into FX nodes + """ + return expr.has( + torch.utils._sympy.functions.ToFloat, + torch.utils._sympy.functions.TruncToInt, + torch.utils._sympy.functions.CeilToInt, + ) + + +@dataclass(frozen=True) +class SymbolicContext: + """ + Data structure specifying how we should create symbols in + ``create_symbolic_sizes_strides_storage_offset``; e.g., should + they be static or dynamic. + + This is an abstract base class because we are probably going to add + another version of this that says "use exactly these SymInts, don't + allocate fresh symbols." + """ + + +@dataclass(frozen=True) +class SymIntSymbolicContext(SymbolicContext): + """ + Data structure specifying any constraints on a SymInt input + """ + + constraint: DimConstraint + + +_P1 = ParamSpec("_P1") +_T1 = TypeVar("_T1") + + +@dataclass(frozen=True) +class StatelessSymbolicContext(Generic[_P1, _T1], SymbolicContext): + """ + Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via + a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``. + This will cause fresh symbols to be allocated + """ + + dynamic_sizes: DimList[DimDynamic] + dynamic_strides: DimList[DimDynamic] = None # type: ignore[assignment] + constraint_sizes: DimList[DimConstraint] = None # type: ignore[assignment] + constraint_strides: DimList[DimConstraint] = None # type: ignore[assignment] + specialize_on: Optional[list[list[Callable[_P1, _T1]]]] = None + # If the tensor is a view, this should be populated for the base. It contains + # information on how to allocate symbols when recursively fakeifying the base + # during view fake-ification. + view_base_context: Optional[SymbolicContext] = None + # TODO: add storage offset and stride symbolic_context + + def __post_init__(self) -> None: + if self.specialize_on is None: + object.__setattr__( + self, + "specialize_on", + [[]] * len(self.dynamic_sizes), + ) + if self.dynamic_strides is None: + object.__setattr__( + self, + "dynamic_strides", + [DimDynamic.INFER_STRIDE] * len(self.dynamic_sizes), + ) + if self.constraint_sizes is None: + object.__setattr__( + self, "constraint_sizes", [None] * len(self.dynamic_sizes) + ) + if self.constraint_strides is None: + object.__setattr__( + self, "constraint_strides", [None] * len(self.dynamic_sizes) + ) + assert all( + stride in (DimDynamic.INFER_STRIDE, DimDynamic.DYNAMIC, DimDynamic.DUCK) + for stride in self.dynamic_strides + ) + + +# note [Tensor Fakification and Symbol Caching] +# +# As of the time of this note, dynamo creates a fresh fake tensor mode for backends. +# The reason we do this is because there are certain classes of operations, namely, +# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor +# state at the end of a dynamo trace is different than the fake tensor state at the beginning +# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation, +# view relationships, etc. +# +# As we create a new fake mode, we also lose the memoization that comes with it. Rather than +# transfer the memoization cache, we instead transfer the shape env. However, with this +# comes nuance - as dynamo is selective in how it makes symbolic shapes. Due to strategies in +# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across +# recompilations. +# +# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass +# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext. +# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is +# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors +# created with new fake modes should produce the same exact symbols as the original, providing the same shape_env +# is used. +# TODO(voz): Shape env validation +@dataclass(frozen=True) +class StatefulSymbolicContext(StatelessSymbolicContext): + """ + Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via + a symbolic_context determination as given by a cache of Source:Symbol. A cache hit + will reuse a stored symbol, and a cache miss will write to this cache. + + This behaves like StatelessSymbolicContext, except the cache supersedes the + other values - dynamic_sizes and constraint_sizes will not be read if we cache + hit. + + It is the cache owner's responsibility to maintain the lifecycle of the cache + with respect to different shape_envs, clearing, etc. + """ + + tensor_source: Source = None # type: ignore[assignment] + # Why is this keyed on int first? + # That integer is actually the id of the shape_env. This cache short-circuits symbol + # creation, and we must store it per shape env. Now, while tracing invariants are a single + # shape env per tracing context, and every new frame gets a new shape_env. So where would we have + # multiple shape envs? The answer lies in recording. When we are replaying, replay_shape_env_events + # is invoked, and creates a new shape_env. Replaying events against this new shape_env will + # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never + # get recorded in var_to_val, etc. + # TODO(voz): consider a weakref to the shape_env here + shape_env_to_source_to_symbol_cache: dict[int, dict[str, sympy.Expr]] = None # type: ignore[assignment] + + def __post_init__(self) -> None: + super().__post_init__() + # The None default is annoying, but required because of dataclass limitations + assert self.tensor_source is not None + if not self.shape_env_to_source_to_symbol_cache: + object.__setattr__(self, "shape_env_to_source_to_symbol_cache", {}) + + +@dataclass(frozen=True) +class SubclassSymbolicContext(StatefulSymbolicContext): + """ + The correct symbolic context for a given inner tensor of a traceable tensor subclass + may differ from that of the outer symbolic context. This structure allows for this + flexibility, with inner symbolic contexts mapped via attr -> symbolic context. + """ + + inner_contexts: dict[str, SymbolicContext] = None # type: ignore[assignment] + + def __post_init__(self) -> None: + super().__post_init__() + if self.inner_contexts is None: + self.inner_contexts = {} + + +@dataclass +class TrackedFake: + """ + Tracks the sources of all fake tensors we wrap in Dynamo. + Used by shape guard computation. + """ + + fake: Union[FakeTensor, SymInt] + source: Source + symbolic_context: Optional[SymbolicContext] + + def __hash__(self) -> int: + return hash((self.fake, self.source.name())) + + def __eq__(self, other: object) -> bool: + if isinstance(other, TrackedFake): + return self.fake is other.fake and self.source.name() == other.source.name() + return False + + +def is_symbolic( + val: Union[int, SymInt, float, SymFloat, bool, SymBool], +) -> TypeGuard[Union[SymInt, SymFloat, SymBool]]: + if isinstance(val, (int, float, bool)): + return False + return val.node.is_symbolic() + + +IndicatorTypes = (IsNonOverlappingAndDenseIndicator,) + + +def _expandsums(args: list[sympy.Expr]) -> tuple[sympy.Expr, bool]: + """ + Expand products of sums into sums of products. + + This function takes a list of sympy expressions and separates them into + additive expressions (those with is_Add=True) and other expressions. + It then computes the distributive product, expanding (a+b)*(c+d) into a*c + a*d + b*c + b*d. + + Args: + args: A list of sympy expressions to expand + + Returns: + A tuple containing: + - The expanded expression as a sympy.Expr + - A boolean indicating whether expansion occurred (True if multiple additive + expressions were present or if there was at least one additive and one other expression) + """ + adds, other = [], [] + for arg in args: + if arg.is_Add: + adds.append(arg) + else: + other.append(arg) + + result = [sympy.Mul(*other)] + for add in adds: + result = [a * b for a, b in itertools.product(result, add.args)] + + result = sympy.Add(*result) + return result, len(adds) > 1 or (len(adds) > 0 and len(other) > 0) + + +def _fast_expand(expr: _SympyT) -> _SympyT: + """ + A faster implementation of sympy's expand function for common cases. + + This function expands expressions like (a+b)^n or (a+b)*(c+d) into sums of products, + but avoids the expensive checks and features of sympy's full expand implementation. + It only recreates objects when necessary to avoid expensive operations. + + Args: + expr: A sympy expression to expand + + Returns: + The expanded expression + """ + + # The expand algorithm in sympy is slow due to all the features is supports + # For eg: e^(-x)*(x-1)/(x+1) is expanded to (x-1)/(e^x + e^x*x) if x is + # positive and (e^(-x)*x-e^(-x))/(x+1) if x is negative. We do not implement + # such features here to avoid expensive checks. We also make sure that we + # only re-create the objects if any of the args changed to avoid expensive + # checks when re-creating objects. + new_args = [_fast_expand(arg) for arg in expr.args] # type: ignore[arg-type] + if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)): + return _fast_expand(expr.func(*new_args)) + + if expr.is_Pow: + base: sympy.Expr + exp: sympy.Expr + base, exp = expr.args # type: ignore[assignment] + if exp.is_Integer and base.is_Add: + if exp > 1: + return sympy.expand_multinomial(expr, deep=False) + elif exp < 0: + return S.One / sympy.expand_multinomial(S.One / expr, deep=False) + elif expr.is_Mul: + num: list[sympy.Expr] = [] + den: list[sympy.Expr] = [] + for arg in expr.args: + if arg.is_Pow and arg.args[1] == -1: + den.append(S.One / arg) # type: ignore[operator, arg-type] + else: + num.append(arg) # type: ignore[arg-type] + + num, num_changed = _expandsums(num) + den, den_changed = _expandsums(den) + if num_changed or den_changed: + return num / den + + return expr + + +@lru_cache(256) +def safe_expand(r: _SympyT) -> _SympyT: + """ + Expand the given symbolic expression by recursively rewriting product of + sums into sum of products (with the product being either a multiplication or + exponentiation). + + NOTE: using this on an intermediate expression may prevent simplification + down the line, e.g., if we eagerly expand `(a + b)^2` into `a^2 + 2ab + b^2`, + we won't be able to simplify `(a^2 + 2ab + b^2) / (a + b)` as easily. + """ + if hasattr(r, "expand"): + try: + return _fast_expand(r) + except RecursionError: + log.warning("RecursionError in _fast_expand(%s)", r) + return r + else: + return r + + +class _SymbolInfo(NamedTuple): + k: sympy.Symbol + vr: Optional[ValueRanges] + val: Optional[sympy.Integer] + is_size_like: bool + + +@lru_cache(None) +def _maybe_evaluate_static_worker( + expr: _SympyT, + # NB: this is a tuple to ensure it can be LRU cached + symbol_info: tuple[_SymbolInfo, ...], + unbacked_only: bool, + size_oblivious: bool, +) -> Optional[_SympyT]: + """ + This variant of ShapeEnv._maybe_evaluate_static has no dependence on + ShapeEnv and thus can be cached indefinitely. It does the "heavy" lifting + for static evaluation, including nontrivial reliance on Sympy simplification + that occurs when we reallocate the symbols + """ + + # Simplify making use of value range lower bound + new_shape_env = {} + new_range_env = {} + for idx, sinfo in enumerate(symbol_info): + k, vr, val, is_size_like = sinfo + if isinstance(val, SingletonInt): + # Skip var_ranges logic for SingletonInt which is only used + # for jagged layout NestedTensors today + continue + assert vr is not None + if size_oblivious and is_size_like: + lower = max(2, vr.lower) + # Clamping size-oblivious to some quantity below sys.maxsize + # helps us determine that f(u0) != sys.maxsize, which is a + # test that is looking for sys.maxsize as a sentinel, but you + # don't really want to worry about it for unbacked SymInts. + # This is similar to the flavor where size oblivious omits + # 0/1, it changes semantics but in a benign way. + upper = min(2**48, vr.upper) + # Excluding the very upper bound can be helpful + if upper > lower: + upper = upper - 1 + # This is a bit dodgy: what this means is that there was a + # size-like unbacked symbol whose upper bound < 2. This + # causes... problems. + if lower <= upper: + vr = ValueRanges(lower, upper) + else: + lower = vr.lower + # Don't do anything if we don't have a nontrivial lower bound + # Also don't do anything if we asked only to simplify unbacked + # SymInt + if lower is -int_oo or (unbacked_only and val is not None) or not vr.is_int: + new_range_env[k] = vr + continue + # The goal is to take our symbols which have various lower bounds + # and reallocate them into new symbols which are exactly positive; + # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in + # [1, inf], where s0 = ess0 + 1. This gives the most information + # to sympy for subsequent simplifications. + # + # Positive means >= 1 + # Positive - 1 means >= 0 + # Positive + lower - 1 means >= lower + # The new symbol 's' is "too low", so when we substitute it in + # we have to increase it by offset (and conversely, the new + # variables have to have their value range bounds adjusted as + # well) + s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True) + + # Note: + # Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers. + # Sympy might give unexepected results when comparing an integer with a non-integer + # Therefore, we cast offset to int here. + # For example: + # shape_0 = sympy.Symbol("shape_0", positive=True, integer=True) + # expr = sympy.Eq(shape_0 - 1/3, 4) + # expr.xreplace({}) # False + offset = int(lower - 1) + new_shape_env[k] = s + offset + new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset) + + # TODO: remove this try catch (esp for unbacked_only) + try: + new_expr = expr.xreplace(new_shape_env) + except RecursionError: + log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env) + return None + + # We need to canonicalize, as after expand we may have something like `a + b = a` and + # sympy will not simplify the a. The two appeareances of the a will then make value ranges + # analysis give lose bounds + new_expr = canonicalize_bool_expr(safe_expand(new_expr)) + if new_expr.is_number: + return new_expr + + # Check if the range can solve it statically + out = bound_sympy(new_expr, new_range_env) + if out.is_singleton(): + return out.lower + + return new_expr if unbacked_only else None + + +def error() -> NoReturn: + raise AssertionError("shouldn't be hit") + + +# TODO: Deduplicate this with torch/_prims_common/__init__.py +def eval_is_non_overlapping_and_dense( + sizes: Sequence[int], strides: Sequence[int] +) -> int: + return int(guard_bool(_eval_is_non_overlapping_and_dense(sizes, strides))) + + +def _eval_is_non_overlapping_and_dense( + sizes: Sequence[int], strides: Sequence[int] +) -> bool: + """ + Evaluates whether a tensor with the given sizes and strides is non-overlapping and dense. + + A tensor is non-overlapping if there's no memory location that belongs to more than one element. + A tensor is dense if all elements are stored in memory without gaps. + + Args: + sizes: Sequence of dimension sizes for the tensor + strides: Sequence of strides for the tensor + + Returns: + True if the tensor is non-overlapping and dense, False otherwise + """ + dim = len(sizes) + + # Short-circuits for tensors of rank one, which are + # non-overlapping and "dense" if their stride is one + # or it is a 0/1 element tensor + if dim == 1: + return strides[0] == 1 or sizes[0] < 2 + + # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous + # Sorts (length, stride) pairs by stride + lengths_and_strides = sorted(zip(sizes, strides), key=operator.itemgetter(1)) + + # Unlike the C++ code, we don't move the 0/1 size dimensions to the + # end. So we have to keep going for this code. + expected_stride = 1 + for length, stride in lengths_and_strides: + if length == 1: + continue + + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +def _sympy_cast_symbool_to_symint_guardless(x: SympyBoolean) -> sympy.Expr: + return sympy.Piecewise((1, x), (0, True)) + + +def cast_symbool_to_symint_guardless( + symbool: Union[bool, torch.SymBool], +) -> Union[int, torch.SymInt]: + """ + Converts a SymBool or bool to a SymInt or int without introducing guards. + + This function maps True to 1 and False to 0, preserving the symbolic nature + of the input when it's a SymBool. Unlike regular casting which might introduce + guards, this function performs the conversion without adding any guards. + + Args: + symbool: A boolean value, either a concrete bool or symbolic SymBool + + Returns: + The corresponding integer value (1 for True, 0 for False) as either + a concrete int or symbolic SymInt + """ + if isinstance(symbool, bool): + return 1 if symbool else 0 + int_sym = _sympy_cast_symbool_to_symint_guardless(symbool.node.expr) + return symbool.node.shape_env.create_symintnode( + int_sym, hint=int(symbool.node.require_hint()) if has_hint(symbool) else None + ) + + +SYMPY_INTERP = { + "IsNonOverlappingAndDenseIndicator": eval_is_non_overlapping_and_dense, + "cast_symbool_to_symint_guardless": cast_symbool_to_symint_guardless, + "math": math, + "torch": torch, +} + + +def _lru_cache( + fn: Callable[..., _T], maxsize: Optional[int] = None +) -> functools._lru_cache_wrapper[_T]: + """ + Wrapper around lru_cache that clears when new info about shapes has been + updated. + + Use lru_cache if the output is always the same, regardless of the + constraints we know now (i.e. evaluate_expr) + + Use _lru_cache otherwise. + + Also note that this depends on _update_version_counter being called on the + shape environment whenever the constraints are updated, otherwise the cache + will not be cleared. + """ + fn_cache = lru_cache(maxsize)(fn) + prior_version = 0 + + if config.validate_shape_env_version_key: + prior_key = None + + @functools.wraps(fn) + def wrapper(self: ShapeEnv, *args: Any, **kwargs: Any) -> _T: + nonlocal prior_version, prior_key + if prior_key is None: + prior_key = self._get_key() + + if prior_version != self._version_counter: + fn_cache.cache_clear() + prior_version = self._version_counter + prior_key = self._get_key() + else: + assert prior_key == self._get_key(), ( + "ShapeEnv cache key changed without version being updated!" + ) + + return fn_cache(self, *args, **kwargs) + + else: + + @functools.wraps(fn) + def wrapper(self: ShapeEnv, *args: Any, **kwargs: Any) -> _T: # type: ignore[misc] + nonlocal prior_version + if prior_version != self._version_counter: + fn_cache.cache_clear() + prior_version = self._version_counter + + return fn_cache(self, *args, **kwargs) + + wrapper.cache_clear = fn_cache.cache_clear # type: ignore[attr-defined] + wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined] + return wrapper # type: ignore[return-value] + + +@dataclass(frozen=True) +class RuntimeAssert: + """ + This is pretty similar to ShapeGuard but it also comes with a message, + and is exclusively used for things that MUST be true (unlike guards, + which can evaluate False, in which case you just choose not to use + a particular specialization) + """ + + expr: SympyBoolean + msg: str = field(repr=False) + stack: CapturedTraceback = field(repr=False) + + +# Used for printing SymExprs in compile_fx +class SymExprPrinter(PythonPrinter): + def _print_Float(self, expr: sympy.Float) -> str: + return str(float(expr)) + + +class _ShapeGuardPrinter(abc.ABC): + """ + Abstract base class for printers that convert symbolic expressions to string representations. + + This class provides common functionality for printing symbolic expressions with + special handling for symbols that represent tensor shapes, strides, etc. + Subclasses implement specific formatting for different output languages. + + Args: + symbol_to_source: Mapping from sympy symbols to their source objects + source_ref: Function to convert a source to its string representation + var_to_sources: Mapping from sympy symbols to their source objects (for error reporting) + """ + + def __init__( + self, + symbol_to_source: Mapping[sympy.Symbol, list[Source]], + source_ref: Callable[[Source], str], + var_to_sources: Mapping[sympy.Symbol, list[Source]], + ) -> None: + self.symbol_to_source = symbol_to_source + self.source_ref = source_ref + self.var_to_sources = var_to_sources + super().__init__() + + def _print_Float(self, expr: sympy.Float) -> str: + """Convert a sympy Float to a Python float string representation.""" + return str(float(expr)) + + def _print_Symbol(self, expr: sympy.Symbol) -> str: + """ + Convert a sympy Symbol to its source representation. + + This method looks up the symbol in symbol_to_source mapping and returns + the string representation of its first source. + + Args: + expr: The sympy Symbol to convert + + Returns: + String representation of the symbol's source + + Raises: + AssertionError: If the symbol is not found in symbol_to_source + """ + assert isinstance(expr, sympy.Symbol), str(type(expr)) + + def repr_symbol_to_source() -> str: + return repr( + { + symbol: [s.name() for s in sources] + for symbol, sources in self.symbol_to_source.items() + } + ) + + assert self.symbol_to_source.get(expr), ( + f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) " + f"not in {repr_symbol_to_source()}. If this assert is failing, it could be " + "due to the issue described in https://github.com/pytorch/pytorch/pull/90665" + ) + return self.print_source(self.symbol_to_source[expr][0]) + + @abc.abstractmethod + def print_source(self, source: Source) -> str: + """ + Convert a source object to its string representation. + + Args: + source: The source object to convert + + Returns: + String representation of the source + """ + ... + + @abc.abstractmethod + def doprint(self, expr: sympy.Expr) -> str: + """ + Convert a sympy expression to its string representation. + + Args: + expr: The sympy expression to convert + + Returns: + String representation of the expression + """ + ... + + +class ShapeGuardPythonPrinter(_ShapeGuardPrinter, PythonPrinter): + """ + Python printer for shape guards that extends the base ShapeGuardPrinter. + + This class provides functionality to print symbolic expressions as Python code, + with caching to improve performance when printing the same expressions multiple times. + It handles printing of sources and expressions according to Python syntax. + + Args: + *args: Arguments passed to the parent classes. + """ + + def __init__(self, *args: Any) -> None: + super().__init__(*args) + self._print_cache: dict[sympy.Expr, str] = {} + + def print_source(self, source: Source) -> str: + """ + Convert a source object to its string representation using the source_ref function. + + Args: + source: The source object to convert + + Returns: + String representation of the source + """ + return self.source_ref(source) + + def doprint(self, expr: sympy.Expr) -> str: + """ + Convert a sympy expression to its Python string representation with caching. + + This method first checks if the expression is already in the cache. + If found, it returns the cached result; otherwise, it delegates to + PythonPrinter's doprint method and caches the result. + + Args: + expr: The sympy expression to convert + + Returns: + String representation of the expression in Python syntax + """ + val = self._print_cache.get(expr, None) + if val is not None: + return val + else: + res = PythonPrinter.doprint(self, expr) + self._print_cache[expr] = res + return res + + +@deprecated( + "`torch.fx.experimental.symbolic_shapes.ShapeGuardPrinter` is deprecated, " + "please use `torch.fx.experimental.symbolic_shapes.ShapeGuardPythonPrinter` instead.", + category=FutureWarning, +) +class ShapeGuardPrinter(ShapeGuardPythonPrinter): + pass + + +class _ShapeGuardCppPrinter(_ShapeGuardPrinter, CppPrinter): + def __init__(self, *args: Any) -> None: + self.all_symbols: set[str] = set() + self.source_to_symbol: dict[Source, sympy.Symbol] = {} + super().__init__(*args) + + def print_source(self, source: Source) -> str: + if source in self.source_to_symbol: + return self.source_to_symbol[source].name + + source_name = source.name() + mangled_name = re.sub("[^0-9a-zA-Z_]+", "_", source_name) + old_mangled_name = mangled_name + count = 0 + while mangled_name in self.all_symbols: + mangled_name = f"{old_mangled_name}_{count}" + count += 1 + self.source_to_symbol[source] = sympy.Symbol(mangled_name) + self.all_symbols.add(mangled_name) + return mangled_name + + def doprint(self, expr: sympy.Expr) -> str: + return CppPrinter.doprint(self, expr) + + +# A dataclass for storing shape guards +@dataclass(frozen=True) +class _ShapeGuardsHelper: + exprs: list[str] + + +# A dataclass for storing C++ expressions and helper variables +@dataclass(frozen=True) +class _CppShapeGuardsHelper(_ShapeGuardsHelper): + source_to_symbol: dict[Source, sympy.Symbol] + + +class LoggingShapeGuardPrinter(ShapeGuardPythonPrinter): + def __init__(self, var_to_sources: Mapping[sympy.Symbol, list[Source]]): + super().__init__(var_to_sources, lambda n: n.name(), var_to_sources) + + +class DynamicDimConstraintPrinter(PythonPrinter): + """ + Printer for dynamic dim constraints. + - Instead of symbol s_k it prints its source t.size()[i] + - Instead of Eq(_, _), Mod(_, _), etc. it prints _ == _, _ % _, etc. + + We use this to suggest code for specifying dynamic dim constraints. + """ + + def __init__( + self, + symbol_to_source: dict[sympy.Symbol, list[Source]], + source_name_to_debug_name: Mapping[str, str], + ): + super().__init__() + self.symbol_to_source = symbol_to_source + self.source_name_to_debug_name = source_name_to_debug_name + + def _print_Symbol(self, expr: sympy.Symbol) -> str: + assert isinstance(expr, sympy.Symbol), str(type(expr)) + assert self.symbol_to_source.get(expr), ( + f"Unknown symbol {expr} created by constraints solver" + ) + return self.symbol_to_source[expr][0].name() + + +class DimConstraints: + """ + Custom solver for a system of constraints on symbolic dimensions. + Solutions are "static" values or simplified "dynamic" constraints. + """ + + def __init__( + self, + symbol_to_source: dict[sympy.Symbol, list[Source]], + var_to_val: Mapping[sympy.Symbol, sympy.Integer], + marked_dynamic: set[sympy.Symbol], + source_name_to_debug_name: Mapping[str, str], + ) -> None: + # We try to solve systems of inequalities with 1 free variable. + self._univariate_inequalities: dict[sympy.Symbol, set[SympyBoolean]] = ( + defaultdict(set) + ) + # Among them, we prioritize solving for a free variable that has equalities. + # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys() + # and removing a symbol from the former => removing it from the latter. + self._symbols_with_equalities: set[sympy.Symbol] = set() + # A solution of a free variable with equalities becomes a substitution. + # We use these substitutions to simplify other constraints. + # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions. + self._substitutions: dict[sympy.Symbol, sympy.Integer] = {} + + # In general, constraints may have // and % operations. + # Of course, // can be expressed in terms of / and %. + # Our inequality solver can handle / but not %. So we need to transform them away. + # We do so by using the values of variables as hints to evaluate %. + # For soundness we record additional congruence guards and solve them separately. + self._var_to_val: Mapping[sympy.Symbol, sympy.Integer] = var_to_val + self._congruences: defaultdict[sympy.Symbol, set[sympy.Expr]] = defaultdict(set) + + # We do not try to (directly) solve inequalities with > 1 free variables. + # NOTE: free variables in these inequalities cannot also be in _substitutions. + self._multivariate_inequalities: set[SympyBoolean] = set() + + # We park external equalities between free variables here. + self._symbolic_equivalences: list[tuple[Source, sympy.Expr]] = [] + + # Solutions come in two forms: + # - (static) specializations + # - (dynamic) inequalities / congruences + self._static_results: set[str] = set() + self._dynamic_results: set[str] = set() + + # printer for solutions + self._dcp = DynamicDimConstraintPrinter( + symbol_to_source, source_name_to_debug_name + ) + + # inconsistencies found on substituting with concrete values / static solutions + self._inconsistencies: list[str] = [] + + # symbols that are marked dynamic + self._marked_dynamic = marked_dynamic + + # track supported sympy functions and subtract from list of all sympy functions + self._supported_sympy_functions: set[sympy.Function] = { + Application, + Mod, + PythonMod, + FloorDiv, + } + self._enumerate_sympy_functions() + + def rewrite_with_congruences(self, s: sympy.Symbol, expr: _SympyT) -> _SympyT: + """ + Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k. + This leaves rational operators (in particular of the form b / d) that our inequality solver can handle. + We solve the added congruences separately (using our congruence solver, see below). + """ + + def mod_handler(*args: sympy.Expr) -> sympy.Expr: + # Suppose that we have an expression of the form b % d with free variable s. + # Using the value of s as a "hint," we can evaluate b % d to a value k. + # Then we can rewrite b % d to k while adding the guard b % d == k. + + # NOTE(avik): This abstraction is provably sound but, in general, incomplete. It is complete IFF + # the original expression always evaluates to a constant value (i.e., it does not vary with s). + # In other words, + # - solutions of s with the rewritten expression are guaranteed to also be solutions of s with + # the original expression; + # - while it may be possible to find solutions of s with the original expression that are not + # solutions with the rewritten expression, in that case the original expression cannot evaluate + # to the same value for all solutions of s. + # + # Should we be worried about this incompleteness? No, because of the following reasons: + # 1. It unblocks dramatic simplification that would not be otherwise possible with current tech + # (i.e., "don't let perfect be the enemy of the good"). + # 2. We already have a tradition of using hints to add guards in the compiler for making progress. + # 3. We have not yet seen a counterexample arise in practice! In particular, any congruence guards + # we generate (or simplify to) seem to be of the form b % d == k where k is a constant. + # + # Here's a theoretical counterexample: 3*s % (s + 1) == s - 2, that is satisfied by all s >= 2. + # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we + # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution! + base, divisor = args + base, divisor = ( + self.rewrite_with_congruences(s, base), + self.rewrite_with_congruences(s, divisor), + ) + mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace( + self._var_to_val + ) + congruence = (base - mod_reduced) % divisor + if congruence != 0: + self._congruences[s].add(congruence) + return mod_reduced + + def floor_div_handler(*args: sympy.Expr) -> sympy.Expr: + # Suppose that we have an expression of the form b // d with free variable s. + # Using the value of s, we can evaluate b % d to a value k. + # Then we can rewrite b // d to (b - k) / d, while adding the guard b % d == k. + + # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d + # and eliminating b % d as above. + base, divisor = args + base, divisor = ( + self.rewrite_with_congruences(s, base), + self.rewrite_with_congruences(s, divisor), + ) + mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace( + self._var_to_val + ) + congruence = (base - mod_reduced) % divisor + if congruence != 0: + self._congruences[s].add(congruence) + # NB: Must not be CleanDiv, it needs to be regular sympy division + # so inequality solver works. This is sort of problematic for + # is_integer tests though haha + return (base - mod_reduced) / divisor + + if expr.has(Mod): + expr = expr.replace(Mod, mod_handler) + # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative + # arguments should be OK. + if expr.has(PythonMod): + expr = expr.replace(PythonMod, mod_handler) + if expr.has(FloorDiv): + expr = expr.replace(FloorDiv, floor_div_handler) + return expr + + def _enumerate_sympy_functions(self) -> None: + module = torch.utils._sympy.functions + all_functions = set() + for attr in dir(module): + if isinstance(func := getattr(module, attr), sympy.FunctionClass): + all_functions.add(func) + self._unsupported_sympy_functions = all_functions.difference( + self._supported_sympy_functions + ) + + def _has_unsupported_sympy_function(self, expr: sympy.Basic) -> bool: + """ + Tracks list of sympy.Functions the export solver doesn't know how to handle. + """ + return expr.has(*self._unsupported_sympy_functions) + + def add(self, expr: SympyBoolean) -> bool: + """Add an expression to the set of constraints. + + Return whether the expression is a trivial constraint (i.e., an obvious tautology). + """ + if expr == sympy.true: + return True + orig_expr = expr + orig_reduced = orig_expr.xreplace(self._var_to_val) + # TODO(avik): https://github.com/pytorch/pytorch/issues/101093 + # It is possible that `expr` will fail the consistency check because of + # precision errors. Specifically, on substituting its free symbols with + # their concrete values, we might end up comparing floats. Until we have + # a fix for this issue, we delay raising such failures. See solve(). + if orig_reduced == sympy.false: + self._inconsistencies.append(f"{orig_expr} is inconsistent!") + if isinstance( + expr, (sympy.Ne, sympy.Or, sympy.And) + ) or self._has_unsupported_sympy_function(expr): + # we're not going to do anything useful with these, so drop them + return False + free_symbols = expr.free_symbols + assert free_symbols, f"Did not expect constraint with no free variables: {expr}" + if len(free_symbols) > 1: + # multivariate: record and move on + self._multivariate_inequalities.add(expr) + else: + # univariate: can solve these immediately + s = next(iter(free_symbols)) + # eliminate // and % (see documentation of `rewrite_with_congruences` above) + old_n_congruences = len(self._congruences[s]) + expr = self.rewrite_with_congruences(s, expr) + new_n_congruences = len(self._congruences[s]) + if expr == sympy.true: + return old_n_congruences == new_n_congruences + reduced = expr.xreplace(self._var_to_val) + if reduced == sympy.false: + self._inconsistencies.append( + f"{expr}, obtained by rewriting {orig_expr} with congruences, " + "is inconsistent!" + ) + if isinstance(expr, sympy.Eq): + # special status for symbols that have equalities (see `solve` below) + self._symbols_with_equalities.add(s) + self._univariate_inequalities[s].add(expr) + return False + + def add_equality(self, source: Source, expr: sympy.Expr) -> None: + """Add an equality constraint""" + if expr.is_number: + # specialization, right here + self._static_results.add(f"{source.name()} == {expr}") + else: + # these will resolve to either specializations or dynamic equality constraints + self._symbolic_equivalences.append((source, expr)) + + def _reduce_congruences(self) -> dict[sympy.Symbol, set[sympy.Expr]]: + reduced_congruences: dict[sympy.Symbol, set[sympy.Expr]] = {} + for s, congruences in self._congruences.items(): + remainder_modulus_pairs = [] + congruences_to_check = set() + for congruence in congruences: + base, divisor = congruence.args + # We are given a congruence of the form base % divisor == 0 with a free variable s. So: + # - we transform this into an equation of the form base = divisor * tmp; + # - we solve this equation for s to get a linear solution with free variable tmp. + tmp = sympy.Symbol("reduce_congruences_tmp", integer=True) + symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s]) + # See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear + # for how to interpret the results. + if s == symbol: + # This means the solution is of the form s = modulus*tmp + remainder. + modulus, remainder = sympy.polys.polytools.div(solution, tmp) + if isinstance(modulus, sympy.Integer) and isinstance( + remainder, sympy.Integer + ): + # Make sure 0 <= remainder <= modulus. + remainder = remainder % modulus + remainder_modulus_pairs.append((remainder, modulus)) + continue + # This means that we did not get a unique solution to the equation. + # No problem, we will check it. + congruences_to_check.add(congruence) + # Finally we solve for a congruence s such that s = r_i mod m_i for each (r_i, m_i). + # The solution will be a congruence of the form s = r mod m. + # NOTE(avik): Since the given m_i may not be pairwise coprime, we can't just use CRT. + if remainder_modulus_pairs: + remainder, modulus = sympy.ntheory.modular.solve_congruence( + *remainder_modulus_pairs + ) + reduced_congruences[s] = {(s - remainder) % modulus} + substitution = { + s: modulus * sympy.Symbol("tmp", integer=True) + remainder + } + reduced_congruences[s].update( + congruence + for congruence in congruences_to_check + if not sympy.checksol(congruence, substitution) + ) + else: + reduced_congruences[s] = congruences_to_check + + return reduced_congruences + + def _raise_inconsistencies(self) -> None: + if self._inconsistencies: + msg = "\n".join(self._inconsistencies) + self._inconsistencies.clear() + raise ValueError(f"The following inconsistencies were found:\n{msg}") + + def solve(self) -> None: + """Solve the system of constraint equations to find simplified constraints""" + self._raise_inconsistencies() + # as long as there are symbols with equalities, solve for them + # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols) + while self._symbols_with_equalities: + s = self._symbols_with_equalities.pop() + exprs = self._univariate_inequalities.pop(s) + solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s) + if isinstance(solution, sympy.And): + solution = next( + (arg for arg in solution.args if isinstance(arg, sympy.Eq)), + solution, + ) + assert isinstance(solution, sympy.Eq), ( + f"Expected an equality constraint for {s}, got {solution}" + ) + symbol, val = solution.args + assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}" + # because this is univariate, the solution is a specialization + self._static_results.add( + f"{self._dcp.symbol_to_source[s][0].name()} == {val}" + ) + # add this as a substitution to simplify other constraints + self._substitutions[s] = val # type: ignore[assignment] + + # simplify multivariate inequalities: some of them will now become univariate! + multivariate_inequalities = self._multivariate_inequalities + self._multivariate_inequalities = set() + for expr in multivariate_inequalities: + self.add(expr.xreplace({s: self._substitutions[s]})) + self._raise_inconsistencies() + + # solve linear congruences + # NOTE(avik): We do not need to solve them for symbols that have already been specialized. + reduced_congruences = self._reduce_congruences() + for s, congruences in reduced_congruences.items(): + for congruence in congruences: + # any congruence that cannot be checked becomes a dynamic constraint as well + if s not in self._substitutions or not sympy.checksol( + congruence, {s: self._substitutions[s]} + ): + if self._is_supported_congruence(congruence): + base, divisor = congruence.args + tmp_name = "_" + str( + self._dcp.source_name_to_debug_name.get( + self._dcp.symbol_to_source[s][0].name(), + self._dcp.symbol_to_source[s][0].name(), + ) + ) + tmp = sympy.Symbol(tmp_name, integer=True) + from torch._dynamo.source import ConstantSource + + self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)] + r = try_solve(sympy.Eq(base, divisor * tmp), s) + assert r is not None + self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1]))) + + # remaining symbols have only pure inequalities (no equalities) + for s, exprs in self._univariate_inequalities.items(): + try: + solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s) + # because this is univariate, the solution is a dynamic (range) constraint + if isinstance(solution, sympy.Or): + solution = next( + iter( + arg + for arg in solution.args + if arg.xreplace(self._var_to_val) + ) + ) + if isinstance(solution, sympy.And): + for arg in solution.args: + self._dynamic_results.add(self._dcp.doprint(arg)) + else: + self._dynamic_results.add(self._dcp.doprint(solution)) + except (NotImplementedError, AssertionError) as e: + log.warning("Failed to reduce inequalities: %s", e) + for expr2 in exprs: + self._dynamic_results.add(self._dcp.doprint(expr2)) + + # simplify symbolic equivalences: some of them will now become specializations! + symbolic_equivalences = self._symbolic_equivalences + self._symbolic_equivalences = [] + for source, expr3 in symbolic_equivalences: + self.add_equality(source, expr3.xreplace(self._substitutions)) + + # remaining symbolic equivalences become dynamic equality constraints + for source, expr3 in self._symbolic_equivalences: + self._dynamic_results.add(f"{source.name()} == {self._dcp.doprint(expr3)}") + + @classmethod + def _is_supported_congruence(cls, congruence: sympy.Expr) -> bool: + base, divisor = congruence.args + # Congruences that can be currently expressed with supported Dim ops are + # of the form (x + a) % b == 0, where x is a Dim and a and b are constants. + # This allows us to derive x as b*y - a for some Dim y. + # (See also documentation of dynamic_shapes._DerivedDim.) + if isinstance(base, sympy.Add): + lhs, rhs = base.args + cond = ( + isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Integer) + ) or (isinstance(lhs, sympy.Integer) and isinstance(rhs, sympy.Symbol)) + else: + cond = isinstance(base, sympy.Symbol) + cond = cond and isinstance(divisor, sympy.Integer) + return cond + + def forced_specializations(self) -> dict[str, sympy.Expr]: + """Returns a dictionary of the names of symbols to their specialized value""" + + def debug_name(src: Source) -> str: + name = src.name() + if self._dcp.source_name_to_debug_name: + return f"{self._dcp.source_name_to_debug_name[name]} = {name}" + else: + return name + + return { + debug_name(self._dcp.symbol_to_source[s][0]): val + for s, val in self._substitutions.items() + if s in self._marked_dynamic + } + + def _is_derived_dim( + self, dim: object + ) -> TypeGuard[torch.export.dynamic_shapes._DerivedDim]: + return isinstance(dim, torch.export.dynamic_shapes._DerivedDim) + + def _is_dim(self, dim: object) -> TypeGuard[torch.export.dynamic_shapes.Dim]: + return isinstance(dim, torch.export.dynamic_shapes.Dim) and not isinstance( + dim, torch.export.dynamic_shapes._DerivedDim + ) + + def _process_derived_dim_roots( + self, + results: dict[str, dict[str, Any]], + name_to_dim: dict[str, Any], + ) -> None: + """ + Here we resolve 2 concerns with derived dims suggested fixes: 1) newly introduced roots, + and 2) root swapping. + + 1) Newly introduced roots appear with modulo guards, e.g. Mod(dx, 2) = 0 suggests + dx is a derived dim equal to 2 * _dx, introducing a new root _dx. Currently the final + suggested fixes handle this correctly, but we can get intermediate results that look like + {"dy": {"eq": "dx + 1"}, "dx": {"eq": "2 * _dx + 1, "min": 3, "max": 15}} + and this routine prettifies this by unifying to a single root, and making each suggestion + either a derived dim or min/max range, not both. + + 2) With suggested fixes for derived dims, roots can be swapped, + e.g. dx, dx - 1 -> dy + 1, dy. Here we don't want to print out the attached name, + since this leads to messages like "dx - 1 = Dim("dx - 1", ...)". + Instead we evaluate the new root value, and remove results for its derivations. + + First we find all the original roots (specified in dynamic_shapes), that are found in the + values of results (i.e. used for computing suggesting fix values). These original roots + (suppose `dx`) are either specialized, unchanged, refined, or swapped + (expressed as a derived dim). If any of the first 3 cases happen, we suggest `dx`'s value + in results, and remove suggestions for derivations of `dx`, assuming the derived relation + is valid. If swapped, we find the new root, and use the fix to evaluate `dx`'s new value, + and then do the same with `dx`'s derivations. + + Assuming the originally specified derived relations are correct is valid, because: + 1) if the relations are plain wrong (e.g. input shape = (6, 4) with spec (dx, dx - 1)) + produce_guards() will catch this and crash before hand. + 2) if the relations are numerically correct but do not match the emitted guard, + for example: + + def forward(self, x, y): + return x.reshape([-1]) + y # guard: s0 * 2 = s1 + inputs = (torch.randn(6, 2), torch.randn(12)) + dx = Dim("dx", min=2, max=32) + dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )} # this matches values but not op + + then this leads to 2 linear equations, and a) produce_guards() is able to solve for + the unique solution of dx = 6 and specialize, and b) the export constraint solver will + raise an issue due to range constraints (a unique solution means not all values in a + range satisfy a guard) and also force specializations. + """ + from torch.export.dynamic_shapes import Dim + + def _check_same_range(c: Mapping[str, int], dim: object) -> bool: + # returns True if c & dim are both min/max ranges with same values + return ( + self._is_dim(dim) + and ("min" in c or "max" in c) + and ( + (dim.min < 2 and c.get("min", 2) == 2) or dim.min == c.get("min", 2) # type: ignore[attr-defined] + ) # let pass if analysis min = 2 and specified min = 0/1 + and dim.max == c.get("max", int_oo) # type: ignore[attr-defined] + ) + + # 1) newly introduced roots + # this part we handle adding newly introduced roots + # these arise from guards like "x.shape[0] % 3 == 0" + # leading to suggested fixes like "dx = 3*_dx" + # extract _dx, and find appropriate min/max values + # + # before, we have something like: + # {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2} + # we want instead: + # {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3} + introduced_roots: dict[str, str] = {} # map new root -> old root + for k, c in list(results.items()): + if "eq" in c and isinstance(c["eq"], sympy.Expr): # derived dim + root = next(iter(c["eq"].free_symbols)) + if str(root) not in name_to_dim: + introduced_roots[str(root)] = k + # calculate necessary min & max + modulus, remainder = sympy.polys.polytools.div(c["eq"], root) + c_min = c.get("min", 2) + min_ = math.ceil((c_min - remainder) / modulus) + c_max = c.get("max", int_oo) + max_ = math.floor((c_max - remainder) / modulus) + # create result & dim + results[str(root)] = {"min": min_, "max": max_} + name_to_dim[str(root)] = Dim(str(root), min=min_, max=max_) + # remove old root min/max bounds + c.pop("min", None) + c.pop("max", None) + + # alter derivations that depend on old root, to unify to new root + # e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2 + for old_root in introduced_roots.values(): + for k, c in list(results.items()): + if ( + "eq" in c + and isinstance(c["eq"], sympy.Expr) + and str(symbol := next(iter(c["eq"].free_symbols))) == old_root + ): # derived dim with root = old_root + new_root_expr = results[str(old_root)]["eq"] # dx=3*_dx+1 + new_expr = c["eq"].subs({symbol: new_root_expr}) # dy=(3*_dx+1)+1 + c["eq"] = new_expr + + # 2) root swapping + # collect all the original roots that are used for calculating values of suggested fixes + # this consists of: + # 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim + # 2) {"dy": "dx + 1"} -> dx: root for suggested fix + modified_roots: set[str] = set() + for k, c in results.items(): + if k not in name_to_dim: # _dynamo.export() may handle source directly + continue + if self._is_dim(name_to_dim[k]) and ("min" in c or "max" in c): # case 1) + modified_roots.add(k) + elif "eq" in c and isinstance(c["eq"], sympy.Expr): # case 2) + root = next(iter(c["eq"].free_symbols)) + assert root is not None + modified_roots.add(str(root)) + + # exclude newly introduced roots, we've already processed these + modified_roots = modified_roots.difference(introduced_roots) + + # evaluate the new value for each root + # this is now either 1) unchanged, 2) refined with a new range, + # or 3) specialized to a concrete value + modified_root_values: dict[str, dict[str, Any]] = {} + for mroot in modified_roots: + swapped_root = True + if mroot in results: + c = results[mroot] + if ("min" in c or "max" in c) or isinstance( # range + c["eq"], int + ): # specialized + # here, the original root is a root Dim or concrete value in results. + # if it is a derived dim, it is swapped, and we handle that below. + if not _check_same_range( + c, name_to_dim[mroot] + ): # ignore if unchanged + modified_root_values[mroot] = c + swapped_root = False + + if swapped_root: + # if the original root has been swapped in results, that means the new root + # is a range (if it had specialized, the original root would have too). + # find this new root, and solve for the original root's range. + for k, c in results.items(): + if k not in name_to_dim: + continue + dim = name_to_dim[k] + if ( + dim.__class__.__name__ == "_DerivedDim" + and dim.root.__name__ == mroot + ): + # only look for min/max root, otherwise root would have specialized + if "min" in c or "max" in c: + expr = sympy.sympify(k) + s = next(iter(expr.free_symbols)) + result = { + "min": try_solve(sympy.Eq(expr, c["min"]), s)[1], # type: ignore[arg-type, index] + "max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type, index] + } + if not _check_same_range( + result, + name_to_dim[mroot], # type: ignore[index, arg-type] + ): # ignore if unchanged + modified_root_values[mroot] = result # type: ignore[index] + break + + # filter out results where the key is a derived dim (e.g. {"dx - 1" : 4}) + # we only want to suggest fixes for the root, to avoid derived names. + # also, remove anything in modified_roots, since we either add new modified values after this, + # or have decided they are unchanged. + for k in list(results.keys()): + if k not in name_to_dim: + continue + if self._is_derived_dim(name_to_dim[k]) or k in modified_roots: + del results[k] + + # update results with modified root values + # now results has the following properties: + # - only contains original roots as keys + # - each root is now either specialized, refined, or derived from another original root + results.update(modified_root_values) + + def prettify_results( + self, + original_signature: inspect.Signature, + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]], + constraint_violation_error: object, + forced_specializations: dict[str, str], + ) -> str: + """Format a message for constraint violation erros""" + from torch.export.dynamic_shapes import _get_dim_name_mapping + + if not self._dcp.source_name_to_debug_name: + # nothing to do + return "" + + def transform(s: str, inverse: bool = False) -> str: + for k, v in self._dcp.source_name_to_debug_name.items(): + s = s.replace(k, v) if not inverse else s.replace(v, k) + return s + + results: defaultdict[str, dict[str, Any]] = defaultdict(dict) + if dynamic_shapes is None: + dynamic_shapes = {} + + def flip(op: str) -> str: + if op == "<=": + return ">=" + if op == ">=": + return "<=" + if op == "<": + return ">" + if op == ">": + return "<" + assert op == "==" + return op + + def relation_with_digit(expr: str, op: str, digit: int) -> None: + if op == "<=": + results[expr]["max"] = digit + elif op == "<": + results[expr]["max"] = digit - 1 + elif op == ">=": + results[expr]["min"] = digit + elif op == ">": + results[expr]["min"] = digit + 1 + else: + assert op == "==" + results[expr]["eq"] = digit + + # retrieve dynamic shapes + name_to_dim = _get_dim_name_mapping(dynamic_shapes) + + for s in self._static_results.union(self._dynamic_results): + t = transform(s) + if t == s: + continue + left, op, right = re.split(r"( == | <= | >= | < | > )", t) + op = op.strip() + if op == "==" and left == right: + continue + if right.isdigit(): + relation_with_digit(left, op, int(right)) + elif left.isdigit(): + relation_with_digit(right, flip(op), int(left)) + else: + assert op == "==", t + try: + results[left]["eq"] = sympy.sympify(right) + except TypeError: # rhs source is not linked to Dim name + pass + + # order forced specializations based on name + forced_specializations = { + k: forced_specializations[k] + for k in sorted( + forced_specializations.keys(), + key=lambda x: x.split(" = ")[1], + ) + } + + buf = "" + if forced_specializations: + debug_names = set() + for k in forced_specializations: + dim = name_to_dim[k.split(" = ")[0]] + if self._is_derived_dim(dim): + debug_names.add(dim.root.__name__) # type: ignore[attr-defined] + else: + debug_names.add(dim.__name__) + + buf += ( + f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! " + 'For more information, run with TORCH_LOGS="+dynamic".\n' + ) + for s, val in forced_specializations.items(): + buf += f" - solving the guards generated for {s} resulted in a specialized value of {val}.\n" + + self._process_derived_dim_roots(results, name_to_dim) + + dims = [] + others = [] + + # order results by source name + results2 = { + k: results[k] + for k in sorted( + results.keys(), + key=lambda x: transform(x, inverse=True), + ) + } + for k, c in results2.items(): + if "eq" in c: + other = c["eq"] + if isinstance(other, int): + others.append(f"{k} = {other}") + elif _is_supported_equivalence(other): + others.append(f"{k} = {other}") + else: + min_ = c.get("min", None) + if min_ == 2: + min_ = None + max_ = c.get("max", None) + if min_ is not None and max_ is not None: + dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})") + elif min_ is not None: + dims.append(f"{k} = Dim('{k}', min={min_})") + elif max_ is not None: + dims.append(f"{k} = Dim('{k}', max={max_})") + else: + dims.append(f"{k} = Dim('{k}')") + + # results2 will get filtered out if no new suggestions, + # this can happen if guards are too complex. + # in that case don't suggest fix + if dims or others: + buf += "\nSuggested fixes:\n " + buf += "\n ".join(dims + others) + + return buf + + +TLS = threading.local() + + +@dataclass(frozen=True) +class ShapeEnvSettings: + """ + Encapsulates all shape env settings that could potentially affect + FakeTensor dispatch. Used when creating dispatch cache keys. + """ + + allow_scalar_outputs: bool + allow_dynamic_output_shape_ops: bool + assume_static_by_default: bool + specialize_zero_one: bool + duck_shape: bool + prefer_deferred_runtime_asserts_over_guards: bool + allow_complex_guards_as_runtime_asserts: bool + trace_asserts: bool + + +@dataclass +class ValueRangesSLoc: + """ + Locations of the guards that triggered lower and upper bound. + """ + + lower: SLoc + upper: SLoc + + +@contextmanager +def _suppress_guards(shape_env: ShapeEnv) -> Iterator[None]: + shape_env._suppress_guards_enter() + try: + yield + finally: + shape_env._suppress_guards_exit() + + +@dataclass +class _FrameLocalResult: + loc: Optional[str] = None + locals: dict[str, Any] = field(default_factory=dict) + symbols: dict[str, str] = field(default_factory=dict) + + +class ShapeEnv: + # This is a wrapper over the actual __init__ function. + # + # Where to add a new constructor parameter to ShapeEnv? + # ===================================================== + # This __init__ function should be used only for parameters related to event recording. + # These are parameters that we don't wish to pass down the road to new ShapeEnv instances + # created from replaying events. + # + # If you wish to add a parameter to the constructor of ShapeEnv, unrelated to event + # recording, do so in the _init function. + def __init__( + self, + *, + should_record_events: Optional[bool] = None, + tracked_fakes: Optional[list[Any]] = None, + **kwargs: Any, + ) -> None: + self._init(**kwargs) + + # Disable event recording when replaying. + kwargs["should_record_events"] = False + + from torch.fx.experimental.validator import translation_validation_enabled + + self._translation_validation_enabled = translation_validation_enabled() + + # If not specified, enable event recording if both: + # - Translation validation is on + # - Translation validation bisection is not disabled + self.should_record_events = ( + should_record_events + if should_record_events is not None + else ( + self._translation_validation_enabled + and not config.translation_validation_no_bisect + ) + ) + + # Enable event recording check if both: + # - It should record events + # - The recording check is enabled + self.check_recorded_events = ( + self.should_record_events and config.check_shape_env_recorded_events + ) + + # This will make sure we only record the top-level function call. + self.is_recording = False + # Keep track of the list of tracked fakes. + self.tracked_fakes = tracked_fakes + # List of events for reconstructing ShapeEnv at arbitrary points in time. + self.events: list[ShapeEnvEvent] = ( + [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] + if self.should_record_events + else [] + ) + + # FakeTensor per-ShapeEnv operation cache. This is used for caching + # operations that contain symbolic shapes which have guards on the + # ShapeEnv (so are ShapeEnv-dependent). + # + # NOTE: It's important that SymNodes in this cache have their ShapeEnv + # stripped otherwise you end up with cycles which can only be cleaned + # with the GC. + self.fake_tensor_cache: dict[ + torch._subclasses.fake_tensor._DispatchCacheKey, + torch._subclasses.fake_tensor._DispatchCacheEntry, + ] = {} + + # Pro-tip: if you add new field to ShapeEnv, this affects some accept + # tests. Accept their output with: + # + # EXPECTTEST_ACCEPT=1 python test/dynamo/test_dynamic_shapes.py -k test_shape_env_equal + # + def _init( + self, + *, + allow_scalar_outputs: bool = True, + allow_dynamic_output_shape_ops: bool = True, + # NB: These are legacy configuration that help us make good choices + # when the constraint/dynamic dims are not explicitly passed to us. + # Ideally we will fix all call sites to be explicit and not have + # implicit choices, but this apparently was pretty involved. + assume_static_by_default: bool = False, + # Note - On 0/1 specialization + # + # The following options affect decisions we make about eager + # specialization. Disabling them will increase trace time (as we do + # more symbolic reasoning) and can also harm the quality of generated + # code (because inductor may not be able to specialize for bounds + # being equal--although if we later respecialize because of a guard, + # your code may be just as good as it was before.) + # + # When True, eagerly specialize input sizes which have 0/1. + specialize_zero_one: bool = True, + # When True, assume input sizes which have the same size are + # symbolically equal. + duck_shape: Optional[bool] = None, + # For debugging + co_fields: Optional[dict[str, str]] = None, + # When True, whenever safe, we will generate a deferred runtime assert + # instead of a guard whenever we know that an expression must be True, + # otherwise it would be an error, even for backed SymInts (where we + # could ostensibly unconditionally generate guards). This is useful + # for export, where preventing "error checking" sizes from showing up + # in guards is helpful, since these guards in some sense are overly + # pedantic. See also https://github.com/pytorch/pytorch/issues/121749 + prefer_deferred_runtime_asserts_over_guards: bool = False, + # When True, does not emit or raise constraint violation errors on + # implicit guards generated by ops, and defers to runtime assertions + # in the graph instead. For export. + allow_complex_guards_as_runtime_asserts: bool = False, + # XXX Add any new settings that could affect FakeTensor evaluation + # to: torch._subclasses.fake_tensor._ShapeEnvSettings + trace_asserts: bool = False, + ) -> None: + if duck_shape is None: + duck_shape = config.use_duck_shape + + self.settings = ShapeEnvSettings( + # Not directly used by ShapeEnv; indirectly used by FakeTensor + allow_scalar_outputs=allow_scalar_outputs, + allow_dynamic_output_shape_ops=allow_dynamic_output_shape_ops, + # End + assume_static_by_default=assume_static_by_default, + specialize_zero_one=specialize_zero_one, + duck_shape=duck_shape, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + trace_asserts=trace_asserts, + ) + + self.guards: list[ShapeGuard] = [] + self.axioms: dict[sympy.Expr, sympy.Expr] = {} + + # A set of ids that have already been allocated. This is used + # for when we allocate symbol ids using the hash of the source + # names to ensure we don't have collisions via linear probing + self.unique_ids: set[int] = set() + # Maps symbolic ints to their original concrete values + # Currently populated from tensors + self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {} + # Like var_to_val, but only set when propagate_real_tensors is on. + # Used as last resort to avoid GuardOnDataDependent error + self.unbacked_var_to_val: dict[sympy.Symbol, sympy.Integer] = {} + # Like above, but used exclusively for OBLIVIOUS_SIZE. These + # potentially could be put together but I am not sure, writing out + # the logic individually before abstracting. + self.oblivious_var_to_val: dict[sympy.Symbol, sympy.Integer] = {} + # Maps symbolic ints to their min/max range. These ranges + # are conservative: the int MUST fall in the range, but the + # range may contain ints which may not actually appear in + # practice + self.var_to_range: dict[sympy.Symbol, ValueRanges] = {} + self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {} + self.source_name_to_debug_name: dict[str, str] = {} + self.var_to_sources: dict[sympy.Symbol, list[Source]] = {} + self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {} + # Maps a source to the *original* symbol that was assigned to it + self.source_to_var: dict[str, sympy.Symbol] = {} + # Maps from sympy ints to expressions representing them + # Populated from equality guards (i.e. a.shape[0] == b.shape[0]) + self.replacements: dict[sympy.Symbol, sympy.Expr] = {} + # The sloc of the guard that triggered this replacement to be added + self.replacements_slocs: dict[sympy.Symbol, SLoc] = {} + self.unbacked_renamings: dict[sympy.Symbol, sympy.Symbol] = {} + # Set holds a % b expressions that evaluate to 0. + self.divisible: set[sympy.Expr] = set() + # Set that holds "size-like" symbols. When we perform + # "size-oblivious" tests, these can be assumed to be >= 2. + self.size_like: set[sympy.Symbol] = set() + # Duck-shaping says that if two input tensors have the same size, + # they get assigned the same symbolic variable + self.val_to_var: dict[int, sympy.Symbol] = {} + self.unbacked_symfloat_counter = itertools.count() + self.unbacked_symint_counter = itertools.count() + # Similar to guards, but these MUST evaluate to true and can + # only be evaluated at runtime midway through (i.e., they always + # involve unbacked symints) + # + # For efficiency reasons, we index in the following way. Suppose you have + # a runtime assert i0 + i1 <= s1. We pick the most recently allocated + # symbol in the source expression and add the assert to the list for + # that symbol e.g., {i1: [i0 + i1 <= s1]}. + # + # We access the runtime asserts in two situations: + # + # - When we are guarding on an expression, we will attempt to + # statically evaluate it, in case the unbacked SymInts can + # simplify away. If we have a runtime assert, we may be able + # to discharge the guard entirely. We only need to attempt + # runtime asserts that mention freevars of the expression in + # question. + # + # - When we are performing codegen (in Inductor for eager, or + # when finalizing the export FX graph), we need to know what + # extra runtime asserts to insert. Whenever an unbacked + # SymInt comes into scope, all runtime asserts involving it + # become eligible for insertion (so long as all of their other + # free unbacked symbols are also in scope). We technically + # can handle any choice of key by kicking inexpressible asserts + # to the next unbacked symbol to wait on, but if we choose the + # latest key, an assert will only show up at the moment when + # we can actually codegen it. + self.deferred_runtime_asserts: dict[ + Optional[sympy.Symbol], list[RuntimeAssert] + ] = {} + # This exists so we can efficiently invalidate the cache (it's used as + # part of the cache key); otherwise we'd have to iterate through + # deferred_runtime_asserts to compute its length + self.num_deferred_runtime_asserts = 0 + self.log = log + self.log.info("create_env") + self.frozen = False + self.runtime_asserts_frozen = False + self.dim_constraints: Optional[DimConstraints] = None + self.counter: Counter[str] = collections.Counter() + # Mapping from sympy.Symbol to the number of guards which mention this + # symbol + self.symbol_guard_counter: Counter[sympy.Symbol] = collections.Counter() + # A selection of important fields on co_field; solely used for + # signpost_event + self.co_fields = co_fields if co_fields else {} + + # Whenever we allocate a fresh unbacked Symbol, we add it to this + # pending list. Unbacked symbol allocation can occur at unpredictable + # points during meta tensor propagation, but at some point, we + # have to know what the binding site for an unbacked symbol is, and + # this is computed when we actually place the node in the graph. The + # important thing is that we always actually handle every unaccounted + # for unbacked symbol, so this list helps us keep track of them and + # then make sure they are all accounted for. + # + # We could potentially give rise to errors earlier by lexically + # scoping when we do propagation, and only allowing unbacked symbols + # to be allocated at this point in time. However this is inconvenient + # to do in Dynamo, because fake tensor propagation is far from when we + # analyze binding sites (set_example_value), so we do it in a more + # mutatey way. + # + # NB: fresh unbacked symbols NEVER get substitutions applied to them, + # they are binding sites! + self.pending_fresh_unbacked_symbols: list[sympy.Symbol] = [] + + # Version counter used to invalidate cached values + self._prev_cache_key = self._get_key() + self._version_counter = 0 + + # Each time divisible is changed this should be set to True, this is set in _update_version_counter. + self._resimplify_floor_div_axioms = True + + # Cache for FX nodes. + # Maps an already built node a tuple of: + # 1. node's target + # 2. list of arguments + # This drastically reduces the size of the FX graph, avoiding + # duplicated nodes. + self.fx_node_cache: dict[tuple[Callable, tuple[Any, ...]], torch.fx.Node] = {} + self.source_to_symbol: dict[str, sympy.Symbol] = {} + + # Suppose you want to replace an unbacked symbol with another + # unbacked symbol. This is error prone because you can cause + # references to unbacked symbols to time travel backwards. E.g., + # + # u1 = x.item() + # ... use of u1 ... + # u2 = y.item() + # u3 = z.item() + # torch._check(u1 == u2 + u3) + # + # If you replace u1 with u2 + u3, then the use of u1 now + # references u2 and u3 prior to them actually being bound at + # runtime. + # + # To control for this, we track the order unbacked symbols + # were allocated, and only allow substitutions if they respect + # the dependency from this order; an unbacked symbol can only + # be substituted with unbacked symbols that come before it in the + # order. + # + # This also imposes an ordering on the unbacked symbol binding + # sites themselves: you are not allowed to reorder unbacked symbol + # bindings. At the moment, this is not tracked, but we potentially + # could track this at the IR level using a higher order operator + # with something like effect token tracking. + self.unbacked_alloc_order: dict[sympy.Symbol, int] = {} + + self.user_specialization_stacks: dict[Source, traceback.StackSummary] = {} + self.framework_specialization_stacks: dict[Source, traceback.StackSummary] = {} + + self.trace_asserts = trace_asserts + + self.specializations: OrderedSet[Specialization] = OrderedSet() + + from torch.fx.experimental.validator import translation_validation_enabled + + self._translation_validation_enabled = translation_validation_enabled() + + if self._translation_validation_enabled: + from torch.fx.experimental.validator import TranslationValidator + + self.validator = TranslationValidator() + self.graph = torch.fx.Graph() + # Create an output graph and start inserting before that. + # This is needed when 'deepcopy'-ing this object. + self.graph.inserting_before(self.graph.output(None)) + + # Mapping of each node name to the node itself. + # + # This is useful for matching an FX node from a recorded ShapeEnv.graph + # to the FX node of the ShapeEnv we are running the event on. + # + # Whenever you add a node to self.graph, you must add a mapping to this + # variable. Otherwise, the built FX graph on the replayed ShapeEnv will + # not be valid. + self.name_to_node: dict[str, torch.fx.Node] = {} + + @property + def allow_scalar_outputs(self) -> bool: + return self.settings.allow_scalar_outputs + + @property + def allow_dynamic_output_shape_ops(self) -> bool: + return self.settings.allow_dynamic_output_shape_ops + + @property + def assume_static_by_default(self) -> bool: + return self.settings.assume_static_by_default + + @property + def specialize_zero_one(self) -> bool: + return self.settings.specialize_zero_one + + @property + def duck_shape(self) -> bool: + return self.settings.duck_shape + + @property + def prefer_deferred_runtime_asserts_over_guards(self) -> bool: + return self.settings.prefer_deferred_runtime_asserts_over_guards + + @property + def allow_complex_guards_as_runtime_asserts(self) -> bool: + return self.settings.allow_complex_guards_as_runtime_asserts + + @contextmanager + def patch_source_specialization( + self, source: Source, check_fn: Callable[[sympy.Symbol], sympy.Expr] + ) -> Iterator[None]: + """ + Temporarily add symbol-level axioms to the ShapeEnv. This is useful when you want to "fork" + and have parallel universes of ShapeEnvs. For example, we use this when doing multi-graph + compile so we can support various graphs with varying levels of specializations. + + This context manager allows for temporarily adding constraints to the shape environment + based on a specialization function applied to a symbol associated with a source. + + Args: + source: The source of the symbol to specialize + check_fn: A function that takes a sympy Symbol and returns a sympy expression + representing a constraint/specialization to be applied + """ + name = source.name() + sym = self.source_to_var[name] + expr = check_fn(SymInt(SymNode(sym, self, int, None))).node._expr + new_axioms = dict(self.get_implications(self.simplify(expr))) + added_replacements = {} + + for axiom in new_axioms: + if ( + isinstance(axiom, sympy.Eq) + and isinstance(axiom.lhs, sympy.Symbol) + and isinstance(axiom.rhs, sympy.Integer) + and axiom.lhs not in self.replacements + ): + self.replacements[axiom.lhs] = axiom.rhs + added_replacements[axiom.lhs] = axiom.rhs + self.axioms.update(new_axioms) + + # We need to freeze the ShapeEnv becuase any additional modification of + # the ShapeEnv will cause unsoundness for subsequent specialization calls. + self.frozen = True + try: + yield + finally: + for k in new_axioms: + self.axioms.pop(k, None) + for k in added_replacements: + self.replacements.pop(k, None) + self.frozen = False + + def check_equal(self, other: ShapeEnv) -> None: + """Compare another ShapeEnv for equivalence""" + # ShapeEnv fields that are not relevant for the outcome of + # ShapeEnv.produce_guards call: + # - Debugging variables + # - Translation validation related variables + # - Events recording related variables + non_state_variable_names = ( + "counter", + "log", + "var_to_stack", + "fx_node_cache", + "graph", + "validator", + "check_recorded_events", + "should_record_events", + "is_recording", + "tracked_fakes", + "events", + "source_name_to_debug_name", + "_prev_cache_key", + "_version_counter", + "dim_constraints", + # source locations are OK to diverge + "var_to_range_sloc", + "replacements_slocs", + "_resimplify_floor_div_axioms", + "_expr_sym_node_id", + "user_specialization_stacks", + "framework_specialization_stacks", + ) + + # Mapping of the value of each to-be-compared field into the values that + # should actually be compared. + # + # You should modify this if, for example, the field that holds state and + # debugging information. e.g. ShapeGuard holds the actual guard (sympy.Expr) + # and the stack when it was added to the set of guards. In order to compare + # it, we throw away the stack information. + def map_value(key: str, value: Any) -> Any: + if key in ("unbacked_symfloat_counter", "unbacked_symint_counter"): + from copy import copy + + # For itertools.count(), we compare the next integer returned + # by the count iterators. Not that we need to copy the iterator + # first. Otherwise we are mutating the object. + return next(copy(value)) + elif key == "guards": + # Transform the list of ShapeGuard into a list of expressions. + return [g.expr for g in value] + elif key == "deferred_runtime_asserts": + # Transform the list of RuntimeAsserts into a list of expressions. + return {s: [ra.expr for ra in ras] for s, ras in value.items()} + elif key == "name_to_node": + # Compare just the set of keys is the same. + return set(value.keys()) + elif key in ( + "symbol_guard_counter", + "pending_fresh_unbacked_symbols", + "fake_tensor_cache", + ): + # Skip this for comparisons + return None + return value + + shape_env_check_state_equal(self, other, non_state_variable_names, map_value) + + def _snapshot_tracked_fakes(self) -> Optional[list[Any]]: + if self.tracked_fakes is None: + return None + + from torch._dynamo.variables.builder import TrackedFake + + def maybe_transform_fake(fake: TrackedFake) -> TrackedFake: + inner_fake = ( + fake.fake + if isinstance(fake.fake, (torch.SymInt, torch.SymFloat)) + else FakeTensorMeta.from_fake(fake.fake) + ) + # Even though TrackedFake accepts either a Union[SymInt, FakeTensor], here we give it a + # FakeTensorMeta for two reasons: + # 1. this is all the information we need when recording ShapeEnvEvents. + # 2. it works even if each TrackedFake changes its metadata. + return TrackedFake(inner_fake, fake.source, fake.symbolic_context) # type: ignore[arg-type] + + return [maybe_transform_fake(fake) for fake in self.tracked_fakes] + + def _last_event_index(self) -> int: + return len(self.events) - 1 + + @contextmanager + def _recording(self) -> Iterator[None]: + self.is_recording = True + try: + yield + finally: + self.is_recording = False + + @record_shapeenv_event() + def _eliminate_unbacked(self, orig_s: sympy.Symbol, new_s: sympy.Expr) -> None: + self._set_replacement(orig_s, new_s, "eliminate_unbacked") + + @record_shapeenv_event() + def set_unbacked_var_to_val(self, k: sympy.Symbol, v: int) -> None: + """Used only when propagate_real_tensors; registers a value for an + unbacked symbol, which can be used last resort to resolve hints.""" + log.info("set_unbacked_var_to_val %s = %s", k, v) + self.unbacked_var_to_val[k] = sympy.sympify(v) + + # Unlike set_replacement, this records a shapeenv event + @record_shapeenv_event() + def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol) -> None: + assert isinstance(orig_s, sympy.Symbol), orig_s + assert isinstance(new_s, sympy.Symbol), new_s + assert free_unbacked_symbols(new_s), new_s + assert free_unbacked_symbols(orig_s), orig_s + dest = self.replacements.get(orig_s) + if dest is not None: + assert not free_unbacked_symbols(dest), f"{orig_s} -> {dest}" + self._set_replacement(orig_s, new_s, "rename_unbacked_to") + self.unbacked_renamings[orig_s] = new_s + if dest is not None: + self._set_replacement(new_s, dest, "rename_unbacked_to_dest") + + @record_shapeenv_event() + def _constrain_is_bounded(self, a: sympy.Symbol, upper_bound: int) -> None: + # TODO: Do something nontrivial when upper_bound is expression + pass + + @record_shapeenv_event() + def _constrain_range_for_size( + self, a: sympy.Symbol, min: Optional[int] = None, max: Optional[int] = None + ) -> None: + if min is None: + min = 0 + if max is None: + max = int_oo + + if max < min: + raise ValueError( + "Maximum value to constrain_as_size can't be less than the specified min value, " + "received min={min} and max={max}" + ) + + self.constrain_symbol_range( + a, + compiler_min=min, + compiler_max=max, + ) + self.size_like.add(a) + + @record_shapeenv_event() + def _constrain_range(self, a: sympy.Expr, min: int, max: int) -> None: + if isinstance(a, sympy.Integer): + if not (min <= int(a) <= max): + raise ValueRangeError(f"Invalid value {int(a)} for range [{min}:{max}]") + return + + # TODO: Shouldn't we install a guard if the symbol is backed? Or is the + # semantics that this is an "unchecked" assert (but it this actually + # something useful? Might be better to restrict only for unbacked + # SymInt). + if isinstance(a, sympy.Symbol): + self.constrain_symbol_range( + a, + compiler_min=min, + compiler_max=max, + ) + + @record_shapeenv_event() + def _constrain_unify(self, a: SymInt, b: SymInt) -> None: + """ + Given two SymInts, constrain them so that they must be equal. NB: + this will not work with SymInts that represent nontrivial expressions + (yet!) + """ + # TODO: this does not install a deferred runtime assert yet + + # TODO: Maybe dedupe this with _maybe_guard_rel? + # Update Feb 2024: this is extra important to do, this doesn't handle + # unbacked replacements properly nor does it generate deferred runtime + # asserts + if not isinstance(a, SymInt): + if not isinstance(b, SymInt): + assert a == b + else: + assert isinstance(b.node.expr, sympy.Symbol), ( + "constraining non-Symbols NYI" + ) + assert b.node.shape_env is self + self.replacements[b.node.expr] = sympy.Integer(a) + else: + # TODO: Actually, we can support this as long as one of them is a symbol. + # NB: We can't actually do "unification" as our operators are not + # injective + assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + assert a.node.shape_env is self + if not isinstance(b, SymInt): + self.replacements[a.node.expr] = sympy.Integer(b) + else: + assert a.node.shape_env is b.node.shape_env + assert isinstance(b.node.expr, sympy.Symbol), ( + "constraining non-Symbols NYI" + ) + new_var = self._find(a.node.expr) + self.replacements[b.node.expr] = new_var + + def _ignore_fresh_unbacked_symbols_tls(self) -> bool: + return getattr(TLS, "ignore_fresh_unbacked_symbols", False) + + @record_shapeenv_event() + def _ignore_fresh_unbacked_symbols_set(self, b: bool) -> bool: + prev = self._ignore_fresh_unbacked_symbols_tls() + TLS.ignore_fresh_unbacked_symbols = b + return prev + + @contextmanager + def ignore_fresh_unbacked_symbols(self) -> Iterator[None]: + """ + Indicates that the newly allocated unbacked SymInts are being + discarded + """ + prev = self._ignore_fresh_unbacked_symbols_set(True) + try: + yield + finally: + self._ignore_fresh_unbacked_symbols_set(prev) + + @record_shapeenv_event() + def freeze(self) -> None: + """Freeze this ShapeEnv to stop accumulating guards + + A frozen ShapeEnv will ignore any further guards generated on it and + only emit a warning which may lead to accuracy problems. + """ + self.frozen = True + + @record_shapeenv_event() + def freeze_runtime_asserts(self) -> None: + """Freeze this ShapeEnv to stop adding deferred runtime asserts. + + We will error if you try to install a new runtime assert when it is + frozen. This would indicate a lowering violation, or perhaps something + we know statically is already True but we are checking it again in a way + that is not clearly dischargeable. + """ + # self.prefer_deferred_runtime_asserts_over_guards = False + self.runtime_asserts_frozen = True + + def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]: + if not self._translation_validation_enabled: + return None + srcname = source.name() + if source not in self.source_to_symbol: + self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True) + return self.source_to_symbol[srcname] + + def _add_z3var(self, symbol: sympy.Symbol, type: type) -> None: + if self._translation_validation_enabled: + self.validator.add_var(symbol, type) + + def _add_target_expr(self, expr: SympyBoolean) -> None: + if self._translation_validation_enabled: + self.validator.add_target_expr(expr) + + def _add_assertion(self, expr: SympyBoolean) -> None: + if self._translation_validation_enabled: + self.validator.add_assertion(expr) + + def _check_translation_validate(self) -> None: + if self._translation_validation_enabled: + self.validator.validate() + + @record_shapeenv_event() + def _create_fx_call_function( + self, + op: Callable, + args: tuple, + ) -> tuple[Optional[torch.fx.Node], bool]: + # Cache this tuple in order to avoid duplicated nodes. + node_key = (op, args) + # Flags whether the returned node was cached or not. + fresh = False + + if self._translation_validation_enabled and node_key not in self.fx_node_cache: + # Presence of None in the arguments implies that we should ignore this operation. + if any(a is None for a in args): + # We check if we are not mixing SymNode that should not be ignored + # (fx_node is not None) with those that should (fx_node is None). + assert all(not isinstance(a, torch.fx.Node) for a in args) + return None, fresh + + fresh = True + + # If translation validation is enabled, all arguments must have its + # own FX node. + assert all(a is not None for a in args), ( + f"missing arg in FX graph ({op.__name__}): {args}" + ) + node = self.fx_node_cache[node_key] = self.graph.call_function(op, args) + self.name_to_node[node.name] = node + + return self.fx_node_cache.get(node_key, None), fresh + + def _create_fx_placeholder_and_z3var( + self, + symbol: sympy.Symbol, + type: type, + ) -> Optional[torch.fx.Node]: + if not self._translation_validation_enabled: + return None + + node_key = (self.graph.placeholder, (symbol,)) + + # Check if we haven't added this symbol already. + # If so, skip the placeholder creation, as it + # generates invalid Python code. + if node_key not in self.fx_node_cache: + # Add a Z3 variable according to 'type'. + self._add_z3var(symbol, type) + # Create the FX placeholder out of a mangled name. + mangled_name = re.sub( + r"[^a-zA-Z0-9]", "_", re.sub(r"[()]", "", symbol.name) + ) + node = self.fx_node_cache[node_key] = self.graph.placeholder(mangled_name) + self.name_to_node[node.name] = node + # Attach the 'symbol' to the placeholder so that we can retrieve + # the Z3 variable later. + node.meta["symbol"] = symbol + + return self.fx_node_cache[node_key] + + def _remove_fx_node(self, node: Optional[torch.fx.Node]) -> None: + if self._translation_validation_enabled and node is not None: + self.name_to_node.pop(node.name) + self.graph.erase_node(node) + + def _add_fx_node_metadata(self, node: torch.fx.Node) -> None: + from torch._dynamo.utils import get_current_node + + if self.should_record_events: + node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index() + node.meta[CURRENT_NODE_KEY] = get_current_node() + + @staticmethod + def _suppress_guards_tls() -> bool: + return getattr(TLS, "suppress_guards", False) + + @record_shapeenv_event() + def _suppress_guards_enter(self) -> None: + if not hasattr(TLS, "suppress_guards_stack"): + TLS.suppress_guards_stack = [] + old = self._suppress_guards_tls() + TLS.suppress_guards_stack.append(old) + TLS.suppress_guards = True + + @record_shapeenv_event() + def _suppress_guards_exit(self) -> None: + old = ( + TLS.suppress_guards_stack.pop() + if len(TLS.suppress_guards_stack) > 0 + else False + ) + TLS.suppress_guards = old + + def suppress_guards(self) -> _GeneratorContextManager[None]: + """Context manager to ignore all guards generated inside""" + return _suppress_guards(self) + + def _get_key(self) -> tuple[int, int, int, int]: + """ + Defines the current "state" of the guards we've accumulated in this ShapeEnv. + Determines when we need to invalidate our cache + """ + return ( + len(self.replacements), + len(self.divisible), + self.num_deferred_runtime_asserts, + len(self.unbacked_var_to_val), + ) + + def _update_version_counter(self) -> None: + # if the change to shape env effects self.divisible set + # _resimplify_floor_div_axioms. + # This is used to trigger a resimplication of FloorDiv to CleanDivs + # in implication inside the function resimplify_floor_div. + if len(self.divisible) != self._prev_cache_key[1]: + self._resimplify_floor_div_axioms = True + + # The shape environment is queried orders of magnitude more often than + # it is changed, so we summarise the cache key into a linearly + # increasing version counter which is cheaper to check in _lru_cache + + # Only update version counter if the state actually changed + cur_key = self._get_key() + + if self._prev_cache_key != cur_key: + self._prev_cache_key = cur_key + self._version_counter += 1 + + def _produce_dyn_sizes( + self, + ex_size: Sequence[IntLikeType], + source: Source, + symbolic_context: SymbolicContext, + ) -> list[sympy.Expr]: + return self._produce_dyn_sizes_from_int_tuple( + tuple(ex_size), source, symbolic_context + ) + + def _produce_dyn_sizes_from_int_tuple( + self, + tensor_size: Sequence[IntLikeType], + source: Source, + symbolic_context: SymbolicContext, + ) -> list[sympy.Expr]: + assert all(not is_symbolic(val) for val in tensor_size), ( + f"Expect size to be a plain tuple of ints but got {tensor_size}" + ) + from torch._dynamo.source import TensorProperty, TensorPropertySource + + _assert_symbol_context(symbolic_context) + dynamic_dims = symbolic_context.dynamic_sizes # type: ignore[attr-defined] + constraint_dims = symbolic_context.constraint_sizes # type: ignore[attr-defined] + size = [] + for i, val in enumerate(tensor_size): + sym = self.create_symbol( + val, + TensorPropertySource(source, TensorProperty.SIZE, i), + dynamic_dims[i], + constraint_dims[i], + do_not_specialize_zero_one=config.backed_size_oblivious, + symbolic_context=symbolic_context, + ) + if ( + isinstance(symbolic_context, StatelessSymbolicContext) + and symbolic_context.specialize_on + ): + for specialization in symbolic_context.specialize_on[i]: + self.specializations.add( + Specialization( + TensorPropertySource(source, TensorProperty.SIZE, i), + specialization, + ) + ) + if ( + config.backed_size_oblivious + and isinstance(sym, sympy.Symbol) # could be static + and symbol_is_type(sym, SymT.SIZE) + ): + self.size_like.add(sym) + size.append(sym) + return size + + def create_symbolic_sizes_strides_storage_offset( + self, + ex: torch.Tensor, + source: Source, + *, + symbolic_context: Optional[SymbolicContext] = None, + ) -> tuple[ + tuple[IntLikeType, ...], + tuple[IntLikeType, ...], + IntLikeType, + ]: + """ + Returns a list of symbolic sizes and strides for the given tensor. + We try our best to express stride in terms of the sizes, so as to not + introduce new symbolic variables. + """ + + ex_size = tuple( + self._maybe_specialize_sym_int_with_hint(sz) for sz in ex.size() + ) + ex_stride = tuple( + self._maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride() + ) + ex_storage_offset = self._maybe_specialize_sym_int_with_hint( + ex.storage_offset() + ) + + return self._create_symbolic_sizes_strides_storage_offset( + ex_size, + ex_stride, + ex_storage_offset, + [_is_dim_dynamic(ex, i) for i in range(ex.dim())], + source, + symbolic_context=symbolic_context, + ) + + # Dynamo may want to wrap FakeTensors with SymInt sizes up e.g. make_fx(opt_f(), tracing_mode="symbolic"). + # We create symbols in shape_env using the backed hints behind SymInt. + + # Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape. + # produce_guards will trigger specializations on the outer stuff + + # Case 2: when the SymInt is unbacked, we will throw an data dependent error in require_hint(). + # + # It's probably good for now but it's important to note that this approach has implications for + # the original shape_env when checking guards in different order. + + # Example: + # --------- + # Consider a function "opt_f" as shown below: + + # @torch.compile() + # def opt_f(x: bool, y: Tensor): + # if x == True: + # return y + torch.randn([4]) + # else: + # return y + # Depending on the sequence of calls, we might install two different sets of guards: + + # 1. opt_f(False, y): + # - "x == False" (always works for any size y) + + # 2. opt_f(True, y): + # - Triggers recompilation and results in guards like: + # - "x == True and y.size(0) == 4" + # - (or "y.size(0) == 4 and x == True") + + # The order of checking the guards matters. In this specific example: + # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True, + # we may have an unnessary shape speciliazation for y. + def _maybe_specialize_sym_int_with_hint( + self, maybe_sym: IntLikeType + ) -> IntLikeType: + assert isinstance(maybe_sym, (int, torch.SymInt)) + if is_symbolic(maybe_sym): + assert maybe_sym.node.shape_env is not self, ( + "expect the symbol is created from an shape env other than current one." + ) + return maybe_sym.node.require_hint() + return maybe_sym + + @record_shapeenv_event() + def _create_symbolic_sizes_strides_storage_offset( + self, + # NB: SymInt is allowed here due to nested int, normally you don't + # actually pass true symbolic sizes to this function + ex_size: Sequence[IntLikeType], + ex_stride: Sequence[IntLikeType], + ex_storage_offset: IntLikeType, + is_dim_dynamic: Sequence[bool], + source: Source, + *, + symbolic_context: Optional[SymbolicContext] = None, + ) -> tuple[ + tuple[IntLikeType, ...], + tuple[IntLikeType, ...], + IntLikeType, + ]: + dim = len(ex_size) + + # Reimplement the legacy behavior + if symbolic_context is None: + constraint_sizes: list[DimConstraint] = [None] * dim + constraint_strides: list[DimConstraint] = [None] * dim + dynamic_dims = [] + dynamic_strides = [] + for i in range(dim): + # NB: This is encapsulation breaking! Legacy behavior was + # bad. + if is_dim_dynamic[i]: + r = DimDynamic.DYNAMIC + elif self.assume_static_by_default: + r = DimDynamic.STATIC + else: + r = DimDynamic.DUCK + dynamic_dims.append(r) + dynamic_strides.append(r) + dynamic_dims = [DimDynamic.DUCK] * dim + dynamic_strides = [DimDynamic.INFER_STRIDE] * dim + # symbolic_context is None - set one + symbolic_context = StatelessSymbolicContext( + dynamic_sizes=dynamic_dims, + dynamic_strides=dynamic_strides, + constraint_sizes=constraint_sizes, + constraint_strides=constraint_strides, + ) + # We got a StatelessSymbolicContext + _assert_symbol_context(symbolic_context) + constraint_sizes = symbolic_context.constraint_sizes # type: ignore[attr-defined] + constraint_strides = symbolic_context.constraint_strides # type: ignore[attr-defined] + dynamic_sizes = symbolic_context.dynamic_sizes # type: ignore[attr-defined] + dynamic_strides = symbolic_context.dynamic_strides # type: ignore[attr-defined] + + # TODO: make this configurable from outside symbolic_context; we made a symbolic_context + # decision here where if all sizes are static, we are going to + # specialize all of the inner strides/offset too. We don't have to + # do this, and arguably we should ALWAYS allow for dynamic offset, + # this is cheap. + # TODO: This should be DYNAMIC, using DUCK for BC + dynamic_offset = ( + DimDynamic.STATIC + if all(r == DimDynamic.STATIC for r in dynamic_sizes) + else DimDynamic.DUCK + ) + are_sizes_static = all(r == DimDynamic.STATIC for r in dynamic_sizes) + + assert len(dynamic_sizes) == dim, f"{len(dynamic_sizes)} != {dim}" + assert len(dynamic_strides) == dim, f"{len(dynamic_sizes)} != {dim}" + assert len(constraint_sizes) == dim + assert len(constraint_strides) == dim + + from torch._dynamo.source import TensorProperty, TensorPropertySource + + size: list[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple( + ex_size, source, symbolic_context + ) + stride = self._compute_symbolic_stride( + source, + size, + ex_size, + ex_stride, + dynamic_strides, + constraint_strides, + are_sizes_static, + symbolic_context, + ) + + sym_sizes = [ + self.create_symintnode( + sym, + hint=hint, + source=TensorPropertySource(source, TensorProperty.SIZE, i), + ) + for i, (sym, hint) in enumerate(zip(size, ex_size)) + ] + sym_stride = [] + for i, stride_expr in enumerate(stride): + # NB: Don't duck size the stride; instead use the expression + # we computed + assert stride_expr is not None + sym_stride.append( + self.create_symintnode( + stride_expr, + hint=ex_stride[i], + source=TensorPropertySource(source, TensorProperty.STRIDE, i), + ) + ) + sym_storage_offset = self.create_symintnode( + self.create_symbol( + ex_storage_offset, + TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), + dynamic_dim=dynamic_offset, + constraint_dim=None, + symbolic_context=symbolic_context, + ), + hint=ex_storage_offset, + source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), + ) + return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset + + def _compute_symbolic_stride( + self, + source: Source, + size: Sequence[sympy.Expr], + ex_size: Sequence[IntLikeType], + ex_stride: Sequence[IntLikeType], + dynamic_strides: Sequence[DimDynamic], + constraint_strides: Sequence[ + Optional[Union[StrictMinMaxConstraint, RelaxedUnspecConstraint]] + ], + are_sizes_static: bool, + symbolic_context: SymbolicContext, + ) -> list[sympy.Expr]: + from torch._dynamo.source import TensorProperty, TensorPropertySource + + stride: list[Optional[sympy.Expr]] = [None] * len(size) + candidates: dict[IntLikeType, sympy.Expr] = {} + + # iterate over unbound strides in val ascending order with + # index descending as a tie breaker since for cases like + # [(1, 1), (1, 0)], we want to fill in the right most + # stride first. + val_list = [(val, -i) for i, val in enumerate(ex_stride)] + val_list.sort(key=_nested_int_aware_sort) + + for val, neg_i in val_list: + i = -neg_i + contiguous_stride = ( + i != len(ex_stride) - 1 + and ex_stride[i] == ex_size[i + 1] * ex_stride[i + 1] + ) + if val in (0, 1) and not contiguous_stride: + out_stride = sympy.Integer(val) + else: + dynamic_stride = dynamic_strides[i] + if dynamic_stride == DimDynamic.INFER_STRIDE and val in candidates: + # Set stride to a candidate only for DimDynamic.INFER_STRIDE + out_stride = candidates[val] + else: + # Set INFER_STRIDE to STATIC or DUCK depending on sizes + dyn_stride = dynamic_stride + if dynamic_stride == DimDynamic.INFER_STRIDE: + dyn_stride = ( + DimDynamic.STATIC if are_sizes_static else DimDynamic.DUCK + ) + out_stride = self.create_symbol( + val, + TensorPropertySource(source, TensorProperty.STRIDE, i), + dynamic_dim=dyn_stride, + constraint_dim=constraint_strides[i], + symbolic_context=symbolic_context, + ) + stride[i] = out_stride + candidates[ex_size[i] * val] = size[i] * out_stride + + assert all(x is not None for x in stride) + return stride + + @record_shapeenv_event() + def create_symintnode( + self, + sym: sympy.Expr, + *, + hint: Optional[int], + source: Optional[Source] = None, + ) -> IntLikeType: + """Create a SymInt value from a symbolic expression + + If you know what the current hint value of the SymInt to be created + is, pass it into hint. Otherwise, pass None and we will make our best + guess + + """ + if self._translation_validation_enabled and source is not None: + # Create a new symbol for this source. + symbol = self._create_symbol_for_source(source) + assert symbol is not None + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, int) + + # Add an equality assertion for the newly created symbol and 'sym'. + self._add_assertion(sympy.Eq(symbol, sym)) + else: + fx_node = None + + out: IntLikeType + if isinstance(sym, sympy.Integer): + if hint is not None: + assert int(sym) == hint + out = int(sym) + else: + # How can this occur? When we mark_unbacked, we end up with a real + # tensor that has hints for all sizes, but we MUST NOT create a + # SymNode with a hint, because we're hiding the hint from our eyes + # with the unbacked Symbol. And in fact, the hint compute may be + # inconsistent with size oblivious tests. + if free_unbacked_symbols(sym): + hint = None + out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node)) + return out + + @record_shapeenv_event() + def create_symfloatnode( + self, + sym: sympy.Expr, + *, + hint: Optional[int], + source: Optional[Source] = None, + ) -> FloatLikeType: + """Create a SymFloat value from a symbolic expression""" + if self._translation_validation_enabled and source is not None: + # Create a new symbol for this source. + symbol = self._create_symbol_for_source(source) + assert symbol is not None + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, float) + + # Add an equality assertion for the newly created symbol and 'sym'. + self._add_assertion(sympy.Eq(symbol, sym)) + else: + fx_node = None + + out: FloatLikeType + if isinstance(sym, sympy.Float): + if hint is not None: + assert float(sym) == hint + out = float(sym) + else: + # You could give this the same treatment as SymInt above if + # you supported mark_unbacked on a float, but it's a kind of + # strange thing to do though because floats don't get 0/1 + # specialization anyway + if free_unbacked_symbols(sym): + assert hint is None, sym + out = SymFloat(SymNode(sym, self, float, hint, fx_node=fx_node)) + return out + + @record_shapeenv_event() + def create_unspecified_symint_and_symbol( + self, value: int, source: Source, dynamic_dim: DimDynamic + ) -> IntLikeType: + """Create a SymInt wrapping a new unspecified symbol""" + return self.create_symintnode( + self.create_unspecified_symbol( + value, + source=source, + dynamic_dim=dynamic_dim, + ), + hint=value, + source=source, + ) + + def create_symboolnode(self, sym: sympy.Expr) -> SymBool: + """Create a SymBool object from a sympy boolean expression""" + # This function is only being used in serialization, so we do not track it + # for validation. + return SymBool(SymNode(sym, self, bool, None)) + + def _log_create_unbacked_symbol( + self, + prefix: str, + symbol: sympy.Symbol, + vr: ValueRanges, + source: Optional[Source] = None, + sym_node: Optional[SymNode] = None, + ) -> None: + is_debug = config.extended_debug_create_symbol is not None and str( + symbol + ) in config.extended_debug_create_symbol.split(",") + sloc: Union[str, SLoc] + if source is None: + sloc, maybe_extra_debug = self._get_stack_summary(is_debug) + else: + sloc, maybe_extra_debug = source.name(), "" + log.info( + "%s %s [%s, %s] %s%s", + prefix, + symbol, + vr.lower, + vr.upper, + sloc, + maybe_extra_debug, + stack_info=is_debug, + ) + trace_structured( + "create_unbacked_symbol", + metadata_fn=lambda: { + "symbol": str(symbol), + "node_id": id(sym_node), + "vr": f"[{vr.lower}, {vr.upper}]", + "user_stack": structured.get_user_stack(3), + "stack": structured.get_framework_stack(), + }, + ) + + @record_shapeenv_event() + def create_unbacked_symfloat(self) -> SymFloat: + """Create a symbolic float without a hint value""" + symbol: sympy.Symbol = make_symbol( + SymT.UNBACKED_FLOAT, next(self.unbacked_symfloat_counter) + ) + self.counter["create_unbacked_symbol"] += 1 + if not self._ignore_fresh_unbacked_symbols_tls(): + self.pending_fresh_unbacked_symbols.append(symbol) + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = ValueRanges.unknown() + assert vr.is_float + sloc = self._get_sloc() + self.var_to_range_sloc[symbol] = ValueRangesSLoc(sloc, sloc) + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, float) + + sym_node = SymNode(symbol, self, float, None, fx_node=fx_node) + self._log_create_unbacked_symbol( + "create_unbacked_symfloat", symbol, vr, sym_node=sym_node + ) + + return SymFloat(sym_node) + + @record_shapeenv_event() + def create_unbacked_symint(self, source: Optional[Source] = None) -> SymInt: + """Create a symbolic integer without a hint value""" + symbol: sympy.Symbol = make_symbol( + SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True + ) + if not self._ignore_fresh_unbacked_symbols_tls(): + self.pending_fresh_unbacked_symbols.append(symbol) + self.counter["create_unbacked_symbol"] += 1 + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = self._default_unspecified_value_range() + assert vr.is_int + sloc = self._get_sloc() + self.var_to_range_sloc[symbol] = ValueRangesSLoc(sloc, sloc) + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, int) + + sym_node = SymNode(symbol, self, int, None, fx_node=fx_node) + self._log_create_unbacked_symbol( + "create_unbacked_symint", symbol, vr, source, sym_node=sym_node + ) + + return SymInt(sym_node) + + def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool: + """Check if a sympy symbol matches the naming convention for unbacked symbols""" + return symbol_is_type(symbol, SymT.UNBACKED_INT) + + @record_shapeenv_event() + def create_unbacked_symbool(self) -> SymBool: + """Create a symbolic boolean without a hint value""" + symbol: sympy.Symbol = make_symbol( + SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True + ) + if not self._ignore_fresh_unbacked_symbols_tls(): + self.pending_fresh_unbacked_symbols.append(symbol) + self.counter["create_unbacked_symbol"] += 1 + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = ValueRanges(0, 1) + assert vr.is_int + sloc = self._get_sloc("default value range for unbacked SymBool") + self.var_to_range_sloc[symbol] = ValueRangesSLoc(sloc, sloc) + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, bool) + + sym_node = SymNode(sympy.Eq(symbol, 1), self, bool, None, fx_node=fx_node) + self._log_create_unbacked_symbol( + "create_unbacked_symbool", symbol, vr, sym_node=sym_node + ) + + return SymBool(sym_node) + + @record_shapeenv_event() + def create_unspecified_symbol( + self, + val: Union[int, SymInt, float, SymFloat], + source: Source, + dynamic_dim: DimDynamic = DimDynamic.DUCK, + constraint_dim: DimConstraint = None, # NB: includes None + symbolic_context: Optional[StatelessSymbolicContext] = None, + ) -> sympy.Expr: + """ + Create a symbol with an unspecified value + + Compared to standard symbols we do not assume the value is positive, + nor do we specialze on zero or one values. + """ + # 'positive' is None for unspecified symbols, since we can't + # assume that it will be neither positive nor negative. + + # We don't want to specialize zero one val for unspecified symbol + # so that we can always get a new symbol despite val. + return self.create_symbol( + val, + source, + dynamic_dim, + constraint_dim, + positive=None, + do_not_specialize_zero_one=True, + symbolic_context=symbolic_context, + ) + + @record_shapeenv_event() + def create_symbol( + self, + val: int, + source: Source, + dynamic_dim: DimDynamic = DimDynamic.DUCK, + constraint_dim: DimConstraint = None, # NB: includes None + positive: Optional[bool] = True, + do_not_specialize_zero_one: bool = False, + symbolic_context: Optional[StatelessSymbolicContext] = None, + ) -> sympy.Expr: + """Create a new symbol which is tracked by this ShapeEnv""" + # check if constraint_dim is actually static integer + if ( + isinstance(constraint_dim, StrictMinMaxConstraint) + and constraint_dim.vr.lower == constraint_dim.vr.upper + ): + dynamic_dim = DimDynamic.STATIC + if constraint_dim.vr.lower != val: + raise ConstraintViolationError( + f"Static shape constraint of {constraint_dim.vr.lower} does not match input size of {val}, " + f"for {source.name()}" + ) + if symbolic_context: + from torch._dynamo.source import TensorPropertySource + + assert isinstance(source, TensorPropertySource) + # TODO: storage_offset handling? + assert source.idx is not None + symbolic_context.dynamic_sizes[source.idx] = dynamic_dim + symbolic_context.constraint_sizes[source.idx] = None + constraint_dim = None + + # see note [Tensor Fakification and Symbol Caching] + source_name = source.name() + if ( + isinstance(symbolic_context, StatefulSymbolicContext) + and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache + ): + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] = {} + + if ( + isinstance(symbolic_context, StatefulSymbolicContext) + and source_name + and ( + source_name + in symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] + ) + ): + return symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][ + source_name + ] + + if dynamic_dim in (DimDynamic.SIZE_LIKE_UNBACKED, DimDynamic.OBLIVIOUS_SIZE): + out = self.create_unbacked_symint(source).node.expr + self._constrain_range_for_size(out) + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][ + source_name + ] = out + if dynamic_dim is DimDynamic.OBLIVIOUS_SIZE: + self.oblivious_var_to_val[out] = val + return out + + if do_not_specialize_zero_one: + specialize_zero_one = False + else: + specialize_zero_one = self.specialize_zero_one + + assert isinstance(source, Source), f"{type(source)} {source}" + assert not (positive and val < 0), f"positive set for negative value: {val}" + # It's always sound to allocate a symbol as DYNAMIC. If the user + # constrained the symbol, force the symbolic_context to DYNAMIC, because our + # constraint code will do weird stuff if, e.g., it's duck shaped + if constraint_dim is not None: + dynamic_dim = DimDynamic.DYNAMIC + + if dynamic_dim is DimDynamic.STATIC: + out = sympy.Integer(val) + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][ + source_name + ] = out + return out + + elif dynamic_dim is DimDynamic.DUCK: + # duck_shape can be used to globally turn off duck shaping, even + # if it was requested + duck = self.duck_shape + elif dynamic_dim is DimDynamic.DYNAMIC: + duck = False + else: + raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}") + + sloc = self._get_sloc() + + if val in (0, 1) and specialize_zero_one: + if val == 0: + return sympy.S.Zero + else: + return sympy.S.One + elif not duck or val not in self.val_to_var: + # If we're not duck shaping, we always create a new symbol + # Even if we're duck shaping, if we haven't seen this particular + # value before, we also create a new symbol + symbol_id = self._generate_unique_id(source.name()) + if type(val) is int or is_nested_int(val): + sympy_expr = make_symbol( + SymT.SIZE, symbol_id, positive=positive, integer=True + ) + else: + sympy_expr = make_symbol( + SymT.FLOAT, symbol_id, positive=positive, real=True + ) + self.source_to_var[source_name] = sympy_expr + # We always associate vars to vals + if isinstance(val, int): + self.var_to_val[sympy_expr] = sympy.Integer(val) + elif isinstance(val, float): + self.var_to_val[sympy_expr] = sympy.Float(val) + else: + # Only used for jagged layout nested tensors + self.var_to_val[sympy_expr] = SingletonInt( + val.node.nested_int(), coeff=val.node.nested_int_coeff() + ) + + # Do the appending later, because we always want to populate this + self.var_to_sources[sympy_expr] = [] + # Create a Z3 variable for the new symbol. + self._add_z3var(sympy_expr, int) + + if duck: + # Make sure to reuse this symbol for subsequent duck shaping + self.val_to_var[val] = sympy_expr + + if isinstance(val, int): + if positive: + # Add assertions for the newly created symbols + self._add_assertion(sympy_expr > 1) + + # Apply default range, which assumes not zero-one + self.var_to_range[sympy_expr] = self._default_value_range( + do_not_specialize_zero_one + ) + self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc( + self._get_sloc( + "user code shown is first use of this value--the guard itself is not " + "due user code but due to 0/1 specialization in the framework; to " + "avoid specialization try torch._dynamo.mark_unbacked(tensor, dim)" + if self.specialize_zero_one + else None + ), + sloc, + ) + else: + self.var_to_range[sympy_expr] = ( + self._default_unspecified_value_range() + ) + self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc) + + # Small performance optimization: if we have a min-max constraint, + # we can proactively narrow to that range + if isinstance(constraint_dim, StrictMinMaxConstraint): + assert not duck + self._update_var_to_range( + sympy_expr, constraint_dim.vr, is_constraint=True + ) + + vr = self.var_to_range[sympy_expr] + assert vr.is_int + + if val not in vr: + raise ConstraintViolationError( + f"{val} not in range [{vr.lower}, {vr.upper}]" + ) + + range_str = f"[{vr.lower}, {vr.upper}]" + elif isinstance(val, float): + self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo) + self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc) + range_str = f"[{vr.lower}, {vr.upper}]" + assert vr.is_float + else: + # Skip var_range logic for SingletonInt + # Only used for jagged layout nested tensors + range_str = "" + + r = sympy_expr + + is_debug = config.extended_debug_create_symbol is not None and str( + sympy_expr + ) in config.extended_debug_create_symbol.split(",") + maybe_more_info = "" + if not is_debug and os.getenv("TORCHDYNAMO_EXTENDED_ADVICE", "1") not in ( + "0", + "", + ): + maybe_more_info = ( + ", for more info run with " + f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{sympy_expr}" ' + "or to suppress this message run with " + 'TORCHDYNAMO_EXTENDED_ADVICE="0"' + ) + sloc, maybe_extra_debug = self._get_stack_summary(is_debug) + self.log.info( + "create_symbol %s = %s for %s %s %s%s%s", + sympy_expr, + val, + source.name(), + range_str, + sloc, + maybe_more_info, + maybe_extra_debug, + stack_info=is_debug, + ) + trace_structured( + "create_symbol", + metadata_fn=lambda: { + "symbol": str(sympy_expr), + "val": repr(val), + "vr": range_str, + "source": source.name(), + "user_stack": structured.from_traceback( + TracingContext.extract_stack() + ), + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + + self.counter["create_symbol"] += 1 + else: + # This implements duck-shaping: input sizes that match are assigned + # the same symint + r = self.val_to_var[val] + self.source_to_var[source_name] = r + self.log.debug("create_symbol %s duck sized %s", r, source.name()) + + if isinstance(r, sympy.Symbol): + r_sources = self.var_to_sources[r] + r_sources.append(source) + if not source.is_ephemeral() and r_sources[0].is_ephemeral(): + # prefer non-ephemeral source first since it may be guarded on later + r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0] + + # This ensures we get zeros in symbol_guard_counts, which makes + # some queries simpler (since we will accumulate mass on 0 this + # way) + self.symbol_guard_counter[r] = 0 + + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][ + source_name + ] = r + return r + + def add_var_to_val(self, expr: sympy.Symbol, val: int) -> None: + """Adds a new symbol to the symbolic environment.""" + log.debug("add_var_to_val %s %s", expr, val, stack_info=True) + assert expr not in self.var_to_val, f"{expr} already exists" + self.var_to_val[expr] = sympy.Integer(val) + + def _debug_name(self, source: Source) -> str: + src_name = source.name() + return self.source_name_to_debug_name.get(src_name, src_name) + + def _render_range_for_constraint_violation( + self, source: Source, c: Union[StrictMinMaxConstraint, RelaxedUnspecConstraint] + ) -> str: + if isinstance(c, StrictMinMaxConstraint): + lower, upper = c.vr.lower, c.vr.upper + default = self._default_value_range() + if lower <= default.lower: + lower = None + if upper >= default.upper: + upper = None + c_render = ( + f"{self._debug_name(source)} = {source.name()} in the specified range" + ) + if lower is not None and upper is not None: + c_render += f" {lower} <= {self._debug_name(source)} <= {upper}" + elif lower is None and upper is not None: + c_render += f" {self._debug_name(source)} <= {upper}" + elif lower is not None and upper is None: + c_render += f" {lower} <= {self._debug_name(source)}" + return c_render + return c.render(source) + + def produce_guards(self, *args: Any, **kwargs: Any) -> list[str]: + """ + Like produce_guards_verbose, but only returns the non-verbose python guard expressions + (no verbose guards produced.) + """ + return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs + + def produce_guards_verbose( + self, + placeholders: Sequence[FakeTensor], + sources: Sequence[Source], + source_ref: Callable[[Source], str] = lambda n: n.name(), + *, + guards: Optional[list[ShapeGuard]] = None, + input_contexts: Optional[DimList[SymbolicContext]] = None, + # Encodes user-specified input shape equations of the form s = s' and s = fn(s'). + # (See docs on EqualityConstraint for details of the encoding.) + equalities_inputs: Optional[EqualityConstraint] = None, + _simplified: bool = False, + # Indicates if we should produce guards for known static values. + ignore_static: bool = True, + langs: tuple[str, ...] = ("python", "verbose_python"), + ) -> list[_ShapeGuardsHelper]: + """ + Generates a list of guards strings which, when evaluated in a context that + defines tensors for all the sources, returns True or False depending + on if the guards in the list evaluated to True or not. Primarily used by Dynamo, + but this is also helpful for manual testing of guards (see + evaluate_guards_for_args) + + For convenience in testing, a source is allowed to be a str, + in which case we will assume it is a LocalSource + + simplified lets you omit duck sizing, equality and 0/1 guards. + This is useful for testing when you don't care about the boilerplate + guards, and it may be helpful for user output too (be careful though; + some equality guards are nontrivial! It would be nice to get simplified + output to print them too). It's private because it's not + intended for normal use + + Returns guards in python and python with verbose comments (verbose) by + default. + """ + self.log.info("produce_guards") + + # Check if we get to the same ShapeEnv state by replaying the recorded events. + # This will create a new ShapeEnv instance, and call all recorded function + # calls on this new instance. Finally, it will check whether this new instance + # has equal state. + # + # It's important that we do it in the begining of this function, since it modifies + # self.dim_constraints through its execution. Changes that happen in this method + # aren't interesting, since this is the function call we wish to reproduce at the + # end. If we wish to simply reproduce ShapeEnv instances even after this call, + # this method should also be recorded. + if self.check_recorded_events: + shape_env = replay_shape_env_events(self.events) + self.check_equal(shape_env) + + assert len(placeholders) == len(sources), ( + f"len({placeholders}) != len({sources})" + ) + Tensorlike = (torch.Tensor, FakeTensorMeta) + + def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext: + return StatelessSymbolicContext( + # Ignored; only the constraints part is relevant below. + dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(), + dynamic_strides=[DimDynamic.INFER_STRIDE] * t.dim(), + constraint_sizes=[None] * t.dim(), + constraint_strides=[None] * t.dim(), + ) + + # Expand optional inputs, or verify invariants are upheld + if input_contexts is None: + input_contexts = [ + _create_no_constraints_context(t) if isinstance(t, Tensorlike) else None + for t in placeholders + ] + else: + assert len(input_contexts) == len(placeholders) + for i, (t, context) in enumerate(zip(placeholders, input_contexts)): + if isinstance(t, Tensorlike): + if context is None: + input_contexts[i] = _create_no_constraints_context(t) + else: + assert isinstance(t, (SymInt, int, SymFloat, float)) + assert not isinstance(context, list) + + # It took a lot of sweat to figure out the algorithm here. Let's + # explain how it works. + # + # The ShapeEnv lifecycle looks something like this: + # + # - For each input, you either generate a fresh Sympy symbol (s0) to + # represent its value (a binding site), or you reuse some + # preexisting symbol or expression, skipping the symbol allocation + # (e.g., duck sizing to a preexisting symbol, or expressing a + # stride as a multiplication of a separate stride and size.) + # Naively, you might expect to bind a fresh Sympy symbol for + # every input, but this is fairly wasteful as most of these + # symbols immediately simplify away, and if you don't eagerly + # specialize, e.g., 0/1 symbols, you end up with very complicated + # expressions that are not optimizable in practice. + # + # - You perform some compute on these symbols, occasionally + # introducing guards on boolean expressions on these symbols. + # In particular, whenever we guard on equality (_maybe_guard_rel), + # we can simplify shapes; e.g., when s0 == s1 * 2, we can now + # replace all occurrences of s0 with s1 * 2. Sometimes, a + # boolean expression evaluation doesn't introduce a guard, as + # the guard is already entailed by the simplifications we have + # applied. + # + # - In the end, you have a bunch of replacements (saying how to + # simplify shapes) and a bunch of guards (all the equality guards + # are trivial, because they're covered by the replacements). + # + # From the ShapeEnv, we must generate a Python expression that, when + # evaluated on a set of inputs, tells us whether or not these boolean + # expressions would have evaluated in the same way. However, + # we cannot easily compute this, as we elide recording boolean + # expressions when we think they are vacuously true. Thus, we seek + # an approximation: we must generate an expression, if true, would have + # produced an "equivalent" ShapeEnv, which would answer guard + # expressions in the same way. + # + # Our notion of equivalence is a bit subtle. For example, consider + # the ShapeEnv created from an input of size (5, 4) versus (4, 4) + # (no other guards.) Duck sizing would generate (s0, s1) in the first + # case but (s0, s0) in the second. We do NOT assume that size + # variables are disjoint; so in fact a graph that assumes the input + # could be (s0, s1) subsumes (s0, s0) (setting s0 == s1), but not + # vice versa. However, consider an analogous case (1,) versus (2,). + # Duck sizing generates (1,) and (s0,); the (s0,) graph does NOT + # subsume the (1,) graph because we assume that any size variables + # is NOT 0/1 (and make simplifications according to this; e.g., if + # we queried s0 == 0, we would immediately return False without + # returning a guard.) + # + # So, it is perhaps easier to flip things on their head: the guard + # expressions we generate here say what simplifications are valid, + # and what are not. Below, we explain each of the guard expressions + # we generate + + # TODO: Make this more efficient by binding all the size/stride/offsets + # to locals before performing tests on them. + + from torch._dynamo.source import TensorProperty, TensorPropertySource + + # Actual codegen must be delayed as we don't necessarily know what + # the symbol mapping is + input_guards = [] + + symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict( + list + ) + symbol_to_constraints: defaultdict[sympy.Symbol, set[Constraint]] = ( + collections.defaultdict(set) + ) + constraint_violations: list[tuple[bool, str, Callable[[], str]]] = [] + + printers: list[_ShapeGuardPrinter] = [] + py_printer = ShapeGuardPythonPrinter( + symbol_to_source, source_ref, self.var_to_sources + ) + for lang in langs: + if lang in ["python", "verbose_python"]: + printers.append(py_printer) + elif lang == "cpp": + printers.append( + _ShapeGuardCppPrinter( + symbol_to_source, source_ref, self.var_to_sources + ) + ) + else: + raise NotImplementedError(f"Unknown lang: {lang}") + + def record_constraint_violation( + warn_only: bool, + debug_name: str, + msg: str, + hint: Optional[Callable[[], str]] = None, + ) -> None: + constraint_violations.append( + (warn_only, debug_name, lambda: f"{msg}{hint()}" if hint else msg) + ) + + def is_dim(src: object) -> TypeGuard[TensorPropertySource]: + return ( + isinstance(src, TensorPropertySource) + and src.prop is TensorProperty.SIZE + ) + + if equalities_inputs: + source_index = {} + for i, src in enumerate(sources): + source_index[src.name()] = i + + def get_expression(tensor_dim_src: Source) -> sympy.Expr: + fake = placeholders[source_index[tensor_dim_src.base.name()]] # type: ignore[attr-defined] + assert tensor_dim_src.idx is not None # type: ignore[attr-defined] + symint = fake.shape[tensor_dim_src.idx] # type: ignore[attr-defined] + if isinstance(symint, torch.SymInt): + return symint.node.expr + else: + assert type(symint) is int, f"Expected int, got {type(symint)}" + return sympy.Integer(symint) + + for src1, src2 in equalities_inputs.source_pairs: + expr1, expr2 = get_expression(src1), get_expression(src2) # type: ignore[] + # Check whether given input shape values satisfy a specified equation s = s'. + # - Raise when the equation was violated by the given input shape values. + # - Otherwise issue a guard to constrain them. + concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2)) + if not concrete_val: + raise ConstraintViolationError( + f"{src1.name()} = {expr1 if isinstance(expr1, int) else expr1.xreplace(self.var_to_val)}" + " is not equal to " + f"{src2.name()} = {expr2 if isinstance(expr2, int) else expr2.xreplace(self.var_to_val)}" + ) + + for srcEq, root, fn in equalities_inputs.derived_equalities: + expr1 = get_expression(srcEq) + # recall that root is either a phantom symbol or an input source + expr2, debug_name = ( + (root, self.var_to_sources[root][0].name()) + if isinstance(root, sympy.Symbol) + else (get_expression(root), self._debug_name(root)) + ) + expr2_ = fn(expr2) + # Check whether given input shape values satisfy a specified equation s = fn(s'). + # - Raise when the equation was violated by the given input shape values. + # - Otherwise issue a guard to constrain them. + concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_)) + if not concrete_val: + raise ConstraintViolationError( + f"Expected input {srcEq.name()} to be equal to " + f"{fn(sympy.Symbol(debug_name))}, " + f"where {debug_name} = {expr2.xreplace(self.var_to_val)}, " + f"but got {expr1.xreplace(self.var_to_val)}" + ) + + for phantom_symbol in equalities_inputs.phantom_symbols: + # we created additional phantom symbols that are not input shape dimensions + symbol_to_source[phantom_symbol].extend( + self.var_to_sources[phantom_symbol] + ) + + # How do we know what the value of s0 is? Fresh variables can only be + # bound by inputs, so there MUST be some other input which binds the + # variable. If there is no such input, this is an error in our + # system. We record where all symbols come from, to help you diagnose + # why those symbols didn't occur. + # + # In fact, generally speaking it is only possible for the "outermost" + # user of a ShapeEnv to evaluate the guards, because some inputs may + # not be available to inner levels. For example, Dynamo can guard on + # tensors that never actually become graph arguments (they are + # pruned). In this case, only Dynamo knows about these arguments. + def track_symint( + source: Source, val: IntLikeType, constraint: DimConstraint = None + ) -> None: + log.debug("track_symint %s %s %s", LazyString(source.name), val, constraint) + assert not isinstance(val, SymInt) or is_symbolic(val) + + if isinstance(val, SymInt) and val.node.maybe_as_int() is not None: + val = val.node.maybe_as_int() + + if isinstance(val, SymInt): + s = val.node.expr + if isinstance(s, sympy.Symbol): + symbol_to_source[s].append(source) + if constraint is not None and not isinstance( + constraint, RelaxedUnspecConstraint + ): + symbol_to_constraints[s].add(constraint) + else: + constraint_violated = False + if isinstance(constraint, StrictMinMaxConstraint): + # try inferring the ranges of the expr s + sym_vrs = { + x: self.var_to_range.get(x, None) for x in s.free_symbols + } + if any(vr is None for vr in sym_vrs.values()): + # some of the free symbols in s don't have ranges + constraint_violated = True + elif isinstance(constraint, RelaxedUnspecConstraint): + if s.is_number: + i = int(s) + # Don't complain about 0/1 specialization, we + # expect to have to compile in this case anyway + if i not in (0, 1): + constraint_violated = True + if constraint_violated: + assert constraint is not None + + def hint(s: sympy.Expr) -> str: + sexpr = py_printer.doprint(s) + return f"{sexpr}." + + var_with_range = self._render_range_for_constraint_violation( + source, constraint + ) + msg = ( + f"Not all values of {var_with_range} are valid because " + f"{self._debug_name(source)} was inferred to be equal to " + ) + record_constraint_violation( + constraint.warn_only, + self._debug_name(source), + msg, + hint=functools.partial(hint, s), + ) + + input_guards.append((source, s)) + else: + s = sympy.Integer(val) + input_guards.append((source, s)) + constraint_violated = False + if isinstance(constraint, StrictMinMaxConstraint): + if not ( + s == constraint.vr.lower == constraint.vr.upper + ): # allow static constraints + constraint_violated = True + elif isinstance(constraint, RelaxedUnspecConstraint): + # Don't complain about 0/1 specialization, we + # expect to have to compile in this case anyway + if val not in (0, 1): + constraint_violated = True + if constraint_violated: + assert constraint is not None + var_with_range = self._render_range_for_constraint_violation( + source, constraint + ) + user_stack = self.user_specialization_stacks.get(source, None) + framework_stack = self.framework_specialization_stacks.get( + source, None + ) + msg = ( + f"You marked {self._debug_name(source)} as dynamic but your code " + f"specialized it to be a constant ({val}). If you're using mark_dynamic, " + f"either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, " + f"replace it with either Dim.STATIC or Dim.AUTO." + + ( + "\n\nFramework stack:\n" + "".join(framework_stack.format()) + if framework_stack + else "" + ) + + ( + "\n\nUser stack:\n" + "".join(user_stack.format()) + if user_stack + else "" + ) + ) + record_constraint_violation( + constraint.warn_only, self._debug_name(source), msg + ) + + def track_symfloat(source: Source, val: FloatLikeType) -> None: + log.debug("track_symfloat %s %s", LazyString(source.name), val) + assert not isinstance(val, SymFloat) or is_symbolic(val) + + if isinstance(val, SymFloat) and val.node.maybe_as_float() is not None: + val = val.node.maybe_as_float() + + if isinstance(val, SymFloat): + s = val.node.expr + if isinstance(s, sympy.Symbol): + symbol_to_source[s].append(source) + input_guards.append((source, s)) + else: + s = sympy.Float(val) + input_guards.append((source, s)) + + for t, source, context in zip(placeholders, sources, input_contexts): + if isinstance(source, str): + from torch._dynamo.source import LocalSource + + source = LocalSource(source) + assert isinstance(source, Source) + if t is None: + continue + if isinstance(t, (SymInt, int)): + constraint = ( + None if context is None else getattr(context, "constraint", None) + ) + track_symint(source, t, constraint) + continue + elif isinstance(t, (SymFloat, float)): + track_symfloat(source, t) + continue + assert isinstance(t, Tensorlike) + if is_traceable_wrapper_subclass(t): + from torch._dynamo.source import AttrSource + + assert isinstance(context, SubclassSymbolicContext) + + # For subclasses, we need to track symints on BOTH the outer + # and inner tensors. + # TODO: type this better + sources_tensors_constraints: list[tuple[Source, Any, Any, Any]] = [ + (source, t, context.constraint_sizes, context.constraint_strides) + ] + attrs, _ = t.__tensor_flatten__() + for attr in attrs: + inner_t = getattr(t, attr) + inner_context = context.inner_contexts[attr] + sources_tensors_constraints.append( + ( + AttrSource(source, attr), + inner_t, + inner_context.constraint_sizes, # type: ignore[attr-defined] + inner_context.constraint_strides, # type: ignore[attr-defined] + ) + ) + else: + sources_tensors_constraints = [ + (source, t, context.constraint_sizes, context.constraint_strides) # type: ignore[attr-defined] + ] + + for ( + src, + curr_t, + constraint_size, + constraint_stride, + ) in sources_tensors_constraints: + if is_sparse_any(curr_t): + for i, ss in enumerate(curr_t.size()): + property_source = TensorPropertySource( + src, TensorProperty.SIZE, i + ) + track_symint(property_source, ss, constraint_size[i]) + else: + for i, ss in enumerate(curr_t.size()): + property_source = TensorPropertySource( + src, TensorProperty.SIZE, i + ) + track_symint(property_source, ss, constraint_size[i]) + for i, ss in enumerate(curr_t.stride()): + property_source = TensorPropertySource( + src, TensorProperty.STRIDE, i + ) + track_symint(property_source, ss, constraint_stride[i]) + track_symint( + TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), + curr_t.storage_offset(), + ) + + # 1. Every input must equal the final simplified symbolic expression + # stored on the placeholder. Given a placeholder (s0*2, s1), + # if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3. + # This does a lot of work: it covers duck sizing and equality guards. + all_exprs: list[list[str]] = [[] for _ in langs] + self.dim_constraints = DimConstraints( + symbol_to_source, + self.var_to_val, + set(symbol_to_constraints.keys()), + self.source_name_to_debug_name, + ) + + if not _simplified: + for source, expr in input_guards: + srcname = source.name() + if self._translation_validation_enabled: + # Ignore sources that were not turned into SymInts. + if srcname in self.source_to_symbol: + self._add_target_expr( + sympy.Eq(self.source_to_symbol[srcname], expr) + ) + + # Small optimization + if ( + isinstance(expr, sympy.Symbol) + and symbol_to_source.get(expr) + and source == symbol_to_source[expr][0] + ): + continue + + # This logic excludes static values found on tensors from guarding, because + # dynamo's check_tensor_fn does that (see guards.cpp). + # However, for non tensor sources, we still need to guard here. + if ignore_static and isinstance(source, TensorPropertySource): + if expr.is_number: + self.log.debug( + "Skipping guard %s", f"{source_ref(source)} == {expr}" + ) + continue + + if is_dim(source): + self.dim_constraints.add_equality(source, expr) + + for exprs, printer, lang in zip(all_exprs, printers, langs): + res = f"{printer.print_source(source)} == {printer.doprint(expr)}" + + if lang == "verbose_python": + if (s0 := self.source_to_var.get(srcname)) is not None: + if source != self.var_to_sources[s0][0]: + res = ( + f"{res} # duck sizing added this equality because these " + f"variables had the same size {self.var_to_val[s0]} " + "(to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)" + ) + elif (sloc := self.replacements_slocs.get(s0)) is not None: + res = f"{res} # {sloc}" + else: + res = f"{res} # (unknown var {s0}, please file a bug)" + else: + res = f"{res} # (unknown source {srcname}, please file a bug)" + exprs.append(res) + + if ( + isinstance(source, TensorPropertySource) + and source.prop is TensorProperty.SIZE + and equalities_inputs + and len(expr.free_symbols) == 1 + ): + symbol = next(iter(expr.free_symbols)) + if ( + isinstance(expr, sympy.Symbol) + and expr in symbol_to_constraints + and not equalities_inputs.is_equal( + source, symbol_to_source[expr][0] + ) + ): + msg = ( + f"The values of {self._debug_name(source)} = {source.name()} and " + f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name()} " + "must always be equal." + ) + record_constraint_violation( + equalities_inputs.warn_only, self._debug_name(source), msg + ) + + if ( + not isinstance(expr, sympy.Symbol) + and symbol in symbol_to_constraints + and not equalities_inputs.is_derived( + source, + symbol_to_source[symbol][0], + lambda x: expr.xreplace({symbol: x}), + ) + ): + src = symbol_to_source[symbol][0] + msg = ( + f"The values of {self._debug_name(source)} = {source.name()} must always be related to " + f"the values of {self._debug_name(src)} = {src.name()} by " + f"{self._debug_name(source)} = {expr.xreplace({symbol: sympy.sympify(self._debug_name(src))})}." + ) + record_constraint_violation( + equalities_inputs.warn_only, self._debug_name(source), msg + ) + + # NB: Not necessary to report constraint violations here: + # constraints are guaranteed to be on symbols (we've already + # caught constants and non-atomic expressions), so we only + # have relational constraints, but we don't support those + # at the moment + + # 2. Every guard must evaluate to True (but remember many guards + # like s0 == s1*2 because trivial due to simplification) + issued = set() + + def issue_guard(guard: ShapeGuard) -> None: + expr = self.simplify(guard.expr) + + # Avoid re-issueing the same guard. + if expr in issued: + return + + issued.add(expr) + + try: + is_trivial = False + if any( + is_dim(source) + for s in expr.free_symbols + for source in symbol_to_source[s] + ): + assert self.dim_constraints is not None + is_trivial = self.dim_constraints.add(expr) + + for exprs, printer, lang in zip(all_exprs, printers, langs): + guard_expr = printer.doprint(expr) + if lang == "verbose_python": + guard_expr = f"{guard_expr} # {guard.sloc}" + exprs.append(guard_expr) + + self._add_target_expr(expr) + # A non-relational constraint on a single sizevar can violate + # a constraint + if not is_trivial and len(expr.free_symbols) == 1: + symbol = next(iter(expr.free_symbols)) + source = symbol_to_source[symbol][0] + constraints = symbol_to_constraints[symbol] + for c in constraints: + if isinstance(c, StrictMinMaxConstraint): + var_with_range = ( + self._render_range_for_constraint_violation(source, c) + ) + msg = ( + f"Not all values of {var_with_range} " + f"satisfy the generated guard {py_printer.doprint(expr)}." + ) + record_constraint_violation( + c.warn_only, self._debug_name(source), msg + ) + elif isinstance(c, RelaxedUnspecConstraint): + # This is fine, we allow guards here as long as it + # didn't constrain it to one value (we don't + # actually know this; this depends on our + # ValueRanges reasoning capability) + pass + else: + raise AssertionError(f"unrecognized constraint {c}") + except Exception: + self.log.warning("Failing guard allocated at %s", guard.sloc) + raise + + # First, issue all guards. + # This removes all the checks that follow from bounds + # We could simply emit those and also the bounds 2 <= size when necessary + for guard in guards if guards is not None else self.guards: + if ( + self._maybe_evaluate_static( + guard.expr, axioms=(), size_oblivious=guard.size_oblivious + ) + is not None + ): + continue + issue_guard(guard) + + # Because there are guards that export's constraint solver can suggest good fixes for, that we may have + # deferred as runtime asserts, and that produce_guards() alone won't do anything with (e.g. divisiblity guards), + # we want to send runtime asserts to export's constraint solver too. These will still stay in the graph as asserts, + # but export's constraint solver can decide whether to do anything with them (i.e. raise an error and provide + # suggested fixes, or decide it's out of scope and leave as a runtime assert in the graph). + for ra in self.deferred_runtime_asserts.get(None, []): + if self._maybe_evaluate_static(ra.expr, axioms=()) is not None: + continue + expr = self.simplify(ra.expr) + self.dim_constraints.add(expr) + + # 3. Every symbol must be within its value range (this handles 0/1 + # specialization too). + for symbol, sources in symbol_to_source.items(): + r = self.var_to_range.get(symbol) + if r is None: + continue + vr_sloc = self.var_to_range_sloc[symbol] + + assert sources + bounds = [] + rf = source_ref(sources[0]) + verbose_expr = "" + if r.lower not in (-sympy.oo, -int_oo): + if any(is_dim(source) for source in sources): + self.dim_constraints.add(sympy.Ge(symbol, r.lower)) + # Only print lower bound in simplified mode if it is not the + # default + if not _simplified or r.lower != self._default_value_range().lower: + bounds.append(sympy.Le(r.lower, symbol, evaluate=False)) + verbose_expr = f"{r.lower} <= {rf} # {vr_sloc.lower}" + if r.upper not in (sympy.oo, int_oo): + if any(is_dim(source) for source in sources): + self.dim_constraints.add(sympy.Le(symbol, r.upper)) + # nontrivial upper bound is always interesting + bounds.append(sympy.Le(symbol, r.upper, evaluate=False)) + if verbose_expr: + verbose_expr = f"{r.lower} <= {rf} <= {r.upper} # {vr_sloc.lower} and {vr_sloc.upper}" + else: + verbose_expr = f"{rf} <= {r.upper} # {vr_sloc.upper}" + if bounds: + bound = sympy.And(*bounds, evaluate=False) + + for exprs, printer, lang in zip(all_exprs, printers, langs): + if lang == "verbose_python": + exprs.append(verbose_expr) + else: + exprs.append(printer.doprint(bound)) + # NB: verbose_exprs are done above + + # Check constraints + constraints = symbol_to_constraints[symbol] + for c in constraints: + if isinstance(c, StrictMinMaxConstraint): + # TODO: With int_oo, I think this condition is a noop + # now + if not (c.vr & self._default_value_range()).issubset(r): + source = sources[0] + + expr = sympy.And( + sympy.Le(r.lower, symbol), sympy.Le(symbol, r.upper) + ) + guard_expr = py_printer.doprint(expr) + var_with_range = ( + self._render_range_for_constraint_violation(source, c) + ) + msg = f"Not all values of {var_with_range} satisfy the generated guard {guard_expr}" + record_constraint_violation( + c.warn_only, + self._debug_name(source), + msg, + ) + # We NaN specialize, which means similar to 0/1 specialization we + # should assume that the float is NOT nan. This is load bearing + # if you have something like an equality guard, nan will play + # merry hell with the reasoning. + if symbol_is_type(symbol, SymT.FLOAT): + res = f"not math.isnan({py_printer.print_source(sources[0])})" + for exprs, printer, lang in zip(all_exprs, printers, langs): + if lang == "verbose_python": + exprs.append( + f"{res} # implicit guard for float input due to NaN specialization in the framework" + ) + elif lang == "python": + exprs.append(res) + elif lang == "cpp": + exprs.append(f"~std::isnan({printer.print_source(sources[0])})") + else: + raise NotImplementedError(f"Unimplemented for lang: {lang}") + + if constraint_violations: + warn_msgs: list[str] = [] + error_msgs: list[str] = [] + debug_names = set() + for warn_only, debug_name, msg_cb in constraint_violations: + if warn_only: + str_msg = f" {len(warn_msgs) + 1}. {msg_cb()}" + warn_msgs.append(str_msg) + else: + str_msg = f" - {msg_cb()}" + error_msgs.append(str_msg) + debug_names.add(debug_name) + if len(error_msgs) > 0: + debug_names_str = ", ".join(sorted(debug_names)) + err = "\n".join(error_msgs) + raise ConstraintViolationError( + f"Constraints violated ({debug_names_str})! " + 'For more information, run with TORCH_LOGS="+dynamic".\n' + f"{err}" + ) + elif len(warn_msgs) > 0: + log.debug("%s Warning only constraints violated", len(warn_msgs)) + + signpost_event( + "dynamic", + "produce_guards", + { + **self.co_fields, + **self.counter, + "num_guards": len(all_exprs[0]), + "free_symbols": sum(1 for v in symbol_to_source.values() if v), + # The keys are meaningless from an aggregate perspective, so + # don't include them. Biggest first. + "symbol_guard_counts": sorted( + self.symbol_guard_counter.values(), reverse=True + ), + }, + ) + + if self._translation_validation_enabled: + from torch.fx.experimental.validator import PopulateValidator + + # Add all deferred runtime assertions; these are not technically + # handled by produce_guards but we need to put them in the target + # set + for ras in self.deferred_runtime_asserts.values(): + for ra in ras: + self._add_target_expr(ra.expr) + + # Add value range bound guards for all symbols with no trivial bounds. + # Reason: '_maybe_evaluate_static' may eliminate guards based on the + # refined value ranges. + for sym, vr in self.var_to_range.items(): + if vr.lower not in (-sympy.oo, -int_oo): + self._add_target_expr(sympy.Le(vr.lower, sym)) + if vr.upper not in (sympy.oo, int_oo): + self._add_target_expr(sympy.Le(sym, vr.upper)) + + # Before validating, populate the input of the validator with the + # built FX graph. + with fx_traceback.preserve_node_meta(): + PopulateValidator(self.graph, self.validator).run() + + # Only run translation validation when we are not passing custom guards + if guards is None: + self._check_translation_validate() + + helpers: list[_ShapeGuardsHelper] = [] + for exprs, printer, lang in zip(all_exprs, printers, langs): + if lang == "cpp": + assert isinstance(printer, _ShapeGuardCppPrinter) + helpers.append(_CppShapeGuardsHelper(exprs, printer.source_to_symbol)) + else: + helpers.append(_ShapeGuardsHelper(exprs)) + return helpers + + def produce_guards_expression( + self, + placeholders: Sequence[Union[SymInt, FakeTensor]], + *, + guards: Optional[list[ShapeGuard]] = None, + ignore_static: bool = True, + ) -> Optional[str]: + """ + Expected to be used with evaluate_guards_expression(). Produces the guards + for the given placeholders and returns a string expression to be evaluated + by evaluate_guards_expression given concrete values for the placeholders. + """ + from torch._dynamo.source import LocalSource + + arg_names = [f"t{i}" for i in range(len(placeholders))] + produced_guards = self.produce_guards( + placeholders, + [LocalSource(a) for a in arg_names], + guards=guards, + ignore_static=ignore_static, + ) + if produced_guards: + return " and ".join(produced_guards) + return None + + def evaluate_symexpr(self, code: str) -> Union[int, float, bool]: + """ + To be used by compile_fx to evaluate symexprs + """ + args = {str(e): val for e, val in self.var_to_val.items()} + return eval(code, SYMPY_INTERP, args) + + def deserialize_symexpr(self, code: str) -> Union[SymInt, SymFloat, SymBool]: + """ + To be used by compile_fx to deserialize symexprs + """ + args = { + str(e): SymInt(SymNode(e, self, int, int(val), fx_node=None)) + for e, val in self.var_to_val.items() + } + return eval(code, SYMPY_INTERP, args) + + def evaluate_guards_expression(self, code: str, args: Sequence[object]) -> bool: + """ + Expected to be used with produce_guards_expression(). Evaluates an expression + generated by produce_guards_expression for the given concrete args. + """ + arg_names = [f"t{i}" for i in range(len(args))] + return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))}) + + def evaluate_guards_for_args( + self, + placeholders: Sequence[FakeTensor], + args: Sequence[Tensor], + *, + ignore_static: bool = True, + ) -> bool: + """Generate guards for a graph's placeholder values and evaluate the guards with args""" + code = self.produce_guards_expression(placeholders, ignore_static=ignore_static) + if code: + return self.evaluate_guards_expression(code, args) + return True + + def get_pruned_guards(self, symints: Sequence[torch.SymInt]) -> list[ShapeGuard]: + """ + Get a list of guards, but pruned so it only provides guards that + reference symints from the passed in input + """ + symints = { + s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol) + } + guards = [ + g for g in self.guards if all(s in symints for s in g.expr.free_symbols) + ] + return guards + + def bind_symbols( + self, placeholders: Sequence[FakeTensor], args: Sequence[Tensor] + ) -> dict[sympy.Symbol, int]: + """ + Given a paired list of placeholders (fake tensors with + symbolic sizes) and concrete arguments (regular tensors + with real sizes), returns a dictionary mapping each + symbol to its real value. So for example, if you + have a placeholder with size (s0, s1), binding + (2, 4) to it will give you {s0: 2, s1: 4}. This is + not guaranteed to bind ALL symbols in the ShapeEnv; + we can't bind a symbol if it doesn't occur in any placeholder, + and symbols that already have replacements won't get bindings. + + This is a little duplicative with evaluate_guards but + it's different enough that it seemed cleanest to make + another copy. This assumes the guards are already checked, + though if it's cheap we'll check for shenanigans + """ + bindings: dict[sympy.Symbol, int] = {} + + def bind_symint(arg: object, val: object) -> None: + if isinstance(val, SymInt): + assert isinstance(arg, int) + s = val.node.expr + + if isinstance(s, sympy.Symbol): + if s in bindings: + assert bindings[s] == arg, f"{bindings[s]} != {arg}" + else: + bindings[s] = arg + elif isinstance(-s, sympy.Symbol): + if -s in bindings: + assert bindings[-s] == -arg, f"{bindings[-s]} != {-arg}" + else: + bindings[-s] = -arg + + for t, arg in zip(placeholders, args): + if t is None: + continue + if isinstance(t, SymInt): + bind_symint(arg, t) + continue + assert isinstance(t, torch.Tensor) + for i, s in enumerate(t.size()): + bind_symint(arg.size(i), s) + for i, s in enumerate(t.stride()): + bind_symint(arg.stride(i), s) + bind_symint(arg.storage_offset(), t.storage_offset()) + + return bindings + + def get_nontrivial_guards(self) -> list[SympyBoolean]: + """Returns a list of guard expressions that aren't statically known (i.e. not trivial)""" + return [ + self.simplify(guard.expr) + for guard in self.guards + if self._maybe_evaluate_static( + guard.expr, axioms=(), size_oblivious=guard.size_oblivious + ) + is None + ] + + def format_guards(self, verbose: bool = False) -> str: + """Format this shape env's guard expressions with optional traceback info if verbose""" + + return "\n".join( + f" - {guard.expr}{' ' + str(guard.sloc) if verbose else ''}" + for guard in self.guards + ) + + def bound_sympy( + self, expr: sympy.Expr, size_oblivious: bool = False + ) -> ValueRanges: + """Given a sympy expression, computes a ValueRanges bound for what values it can be""" + # TODO: maybe it's guaranteed x in is var_to_range? + var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols} + if size_oblivious: + # Clamp values of size-like variables + # NB: discarding the old upper bound in intentional, per + # https://github.com/pytorch/pytorch/pull/123675 + for x in self.size_like & var_to_range.keys(): + if var_to_range[x] is not None: + # NB: do NOT set upper to 2 ** 48, we're using this solely + # to determine if we can do size-like replacement, the + # upper bound is irrelevant here + var_to_range[x] = ValueRanges(2, int_oo) + return bound_sympy(expr, var_to_range) # type: ignore[arg-type] + + @_lru_cache + def get_axioms( + self, + symbols: Optional[tuple[sympy.Symbol]] = None, + compute_hint: bool = False, + ) -> tuple[SympyBoolean, ...]: + """ + Given the symbols in an expression, it returns all the runtime asserts that have those symbols + concatenated with all the guards. + If symbols is None, it returns all the runtime asserts (and all the guards) + """ + if symbols is None: + runtime_asserts = ( + r.expr for rs in self.deferred_runtime_asserts.values() for r in rs + ) + else: + runtime_asserts = ( + r.expr + for s in symbols + if s not in self.var_to_val + for r in self.deferred_runtime_asserts.get(s, ()) + ) + guards: Iterator[SympyBoolean] = (g.expr for g in self.guards) + axioms: Iterator[SympyBoolean] = itertools.chain(guards, runtime_asserts) + if compute_hint: + axioms = ( + canonicalize_bool_expr(a.xreplace(self.var_to_val)) for a in axioms + ) + return tuple(dict.fromkeys(axioms).keys()) + + @lru_cache(None) + def get_implications( + self, e: SympyBoolean + ) -> tuple[tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]: + """Given a expression, it returns a list of predicates that follow from it""" + equiv: dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {} + + def add_expr(expr: SympyBoolean) -> None: + expr = canonicalize_bool_expr(expr) + if isinstance(expr, (sympy.Eq, sympy.Ne)): + # No need to canonicalize + # TODO We could further canonicalize Eq ordering the lhs and rhs somehow + # With this, we could remove the need for the commutativity part + opposite = sympy.Eq if isinstance(expr, sympy.Ne) else sympy.Ne + # Commutativity of == and != + equiv[type(expr)(expr.lhs, expr.rhs, evaluate=False)] = sympy.true + equiv[type(expr)(expr.rhs, expr.lhs, evaluate=False)] = sympy.true + equiv[opposite(expr.lhs, expr.rhs, evaluate=False)] = sympy.false + equiv[opposite(expr.rhs, expr.lhs, evaluate=False)] = sympy.false + else: + # Expr and negation + equiv[expr] = sympy.true + # we do not pass evaluate=False like others on purpose here! + # we want not(a=b and not ~(a Optional[sympy.Basic]: + """ + Tries to evaluate expr without introducing guards + + If unbacked_only == True, then we only do substitutions on + unbacked SymInts (leaving regular hinted integers alone). This could + result in an expression that still contains backed SymInts, which you + could then potentially guard on. + + Use compute_hint == True if you are trying to compute a non-binding + hint for the particular hint values of backed and unbacked SymInts, + e.g., if s0 happens to be 3 this run, compute_hint will subsitute s0 with 3. + """ + + # axioms with compute hint NYE + assert not compute_hint or not axioms + expr = self.simplify(expr, size_oblivious) + + if compute_hint: + expr = expr.xreplace(self.var_to_val).xreplace(self.unbacked_var_to_val) + + expr = canonicalize_bool_expr(expr) + + def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None: + if not self._resimplify_floor_div_axioms: + return + self._resimplify_floor_div_axioms = False + new_items = {} + for k, v in axioms.items(): + # A FloorDiv in implications could have became CleanDiv at this point, due to new facts + # to the shapeEnv. This handles such issue but its not ideal. This is the only expression + # simplification that depends on the global state of shape env. + # TODO try to get rid of CleanDiv since it breaks the invariant thats simplifications of sympy + # expressions only depend on the expression itself. + if k.has(FloorDiv): + new_items.update({self.simplify(k): v}) + axioms.update(new_items) + + # Pattern matching + if axioms is None: + resimplify_floor_div(self.axioms) + subst = self.axioms + else: + subst = {} + for e in axioms: + if e.free_symbols.issubset(expr.free_symbols): + subst.update(dict(self.get_implications(self.simplify(e)))) + + resimplify_floor_div(subst) + + expr = expr.xreplace(subst) + # TODO: compute hint might have gotten broken here + + fs = expr.free_symbols + + if not fs and (expr.is_number or expr.is_Boolean): + return expr + + if var_to_range is None: + var_ranges = self.var_to_range + else: + var_ranges = dict(var_to_range) + + symbol_info = tuple( + _SymbolInfo( + s, + var_ranges.get(s), + self.var_to_val.get(s), + s in self.size_like, + ) + for s in sorted(fs, key=str) # TODO: speed up sort? + ) + + r = _maybe_evaluate_static_worker( + expr, symbol_info, unbacked_only, size_oblivious + ) + return r + + @_lru_cache + def replace(self, expr: _SympyT) -> _SympyT: + """ + Apply symbol replacements to any symbols in the given expression. + """ + replacements = {} + for s in expr.free_symbols: + r = self._find(s) + + # Micro-optimization: only do replacements if r and s are different + # Otherwise, xreplace is not a no-op and will trigger expensive + # assumption queries if expr has a relational node. + if not r.is_Symbol or r != s: + replacements[s] = r + if replacements: + return safe_expand(expr.xreplace(replacements)) + else: + return expr + + @_lru_cache + def _update_divisible(self) -> None: + new_divisible = set() + for k in self.divisible: + res = self.replace(k) + if not res.is_number: + new_divisible.add(k) + + self.divisible = new_divisible + self._update_version_counter() + + @_lru_cache + def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT: + """Use known constraints and replacements to simplify the given expr""" + expr = safe_expand(expr) + expr = self.replace(expr) + + if size_oblivious and (expr.has(Max) or expr.has(Min)): # type: ignore[has-type] + min_max_replacements = {} + for atom in (*expr.atoms(Max), *expr.atoms(Min)): # type: ignore[has-type] + if len(atom.args) > 2: + continue + a, b = atom.args + if b == 1 or b == 0: + a, b = b, a + if a == 1 or a == 0: + vr = self.bound_sympy(b, size_oblivious=True) + if vr.lower >= a: + min_max_replacements[atom] = b if atom.func is Max else a + elif vr.upper <= a: + min_max_replacements[atom] = a if atom.func is Max else b + if min_max_replacements: + expr = expr.xreplace(min_max_replacements) + + if expr.has(TruncToInt): + trunc_replacements = {} + for atom in expr.atoms(TruncToInt): + if isinstance(atom.args[0], IntTrueDiv): + base, divisor = atom.args[0].args + if base % divisor == 0: + trunc_replacements[atom] = base // divisor + if trunc_replacements: + expr = expr.xreplace(trunc_replacements) + + # TODO it would seem that this pass is not necessary given the + # below replacement of // with /, but for nested FloorDivs + # the non-recursive replacement doesn't work, and + # recursive makes it hard to look up divisibility, + # because existing divisibility info has FloorDiv in it, not / + # for now just do a separate pass to catch common nested case + if expr.has(FloorDiv): + self._update_divisible() + div_replacements = {} + for atom in expr.atoms(FloorDiv): + base, divisor = atom.args + if isinstance(divisor, FloorDiv): + base1, divisor1 = divisor.args + if ( + self.replace(Mod(base, divisor)) in self.divisible + and base == base1 + and self.replace(Mod(base1, divisor1)) in self.divisible + ): + div_replacements[atom] = divisor1 + if div_replacements: + expr = expr.xreplace(div_replacements) + expr = safe_expand(expr) + if expr.has(FloorDiv): + div_replacements = {} + pows = expr.atoms(sympy.Pow) + rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer)) + for fd in expr.atoms(FloorDiv): + base, divisor = fd.args + if self.replace(Mod(base, divisor)) in self.divisible: + div_replacements[fd] = CleanDiv(base, divisor) + if div_replacements: + new_expr = expr.xreplace(div_replacements) + new_expr = safe_expand(new_expr) + new_pows = new_expr.atoms(sympy.Pow) + new_rationals = new_expr.atoms(sympy.Rational).difference( + new_expr.atoms(sympy.Integer) + ) + # divisions simplified away + if new_pows.issubset(pows) and new_rationals.issubset(rationals): + expr = new_expr + return expr + + # TODO: overload for allow_none literal + @lru_cache(256) + def size_hint( + self, expr: sympy.Basic, *, allow_none: bool = False + ) -> Optional[sympy.Basic]: + """ + Gets a size hint for a given expression from the underlying shapes we had. + Does not introduce a guard, so only use this when you can guarantee that + your code is still valid for arbitrary shapes (such as optimization decisions) + """ + result_expr = safe_expand(expr).xreplace(self.var_to_val) + if not result_expr.is_number: + from torch.utils._sympy.singleton_int import SingletonInt + + if isinstance(result_expr, SingletonInt): + return None + r = self._maybe_evaluate_static(result_expr, compute_hint=True) + if r is not None: + return r + if allow_none: + return None + + if self.oblivious_var_to_val: + # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113 + correct_hint = result_expr.xreplace(self.oblivious_var_to_val) + counterfactual_hint = result_expr.xreplace( + {k: max(v, 2) for k, v in self.oblivious_var_to_val.items()} + ) + if ( + not correct_hint.free_symbols + and not counterfactual_hint.free_symbols + ): + if correct_hint == counterfactual_hint: + log.info("oblivious_size hit %s -> %s", expr, correct_hint) + return correct_hint + else: + log.info( + "oblivious_size counterfactual failed %s -> %s != %s", + expr, + correct_hint, + counterfactual_hint, + ) + else: + log.info( + "oblivious_size miss %s -> %s (counterfactual: %s)", + expr, + correct_hint, + counterfactual_hint, + ) + + if self.unbacked_var_to_val: + unsound_expr = result_expr.xreplace(self.unbacked_var_to_val) + if not unsound_expr.free_symbols: + log.warning( + "propagate_real_tensors size_hint(%s) -> %s", expr, unsound_expr + ) + trace_structured( + "propagate_real_tensors", + metadata_fn=lambda: { + "expr": repr(expr), + "result": repr(unsound_expr), + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + self.guard_or_defer_runtime_assert( + sympy.Eq(result_expr, unsound_expr), + f"propagate_real_tensors: {result_expr} == {unsound_expr}", + ) + return unsound_expr + + raise self._make_data_dependent_error(result_expr, expr) + return result_expr + + # NB: keep in sync with size_hint + @lru_cache(256) + def has_hint(self, expr: sympy.Expr) -> bool: + result_expr = safe_expand(expr).xreplace(self.var_to_val) + return ( + result_expr.is_number + or self._maybe_evaluate_static(result_expr) is not None + ) + + def _make_data_dependent_error( + self, + expr: sympy.Basic, + unhinted_expr: sympy.Basic, + *, + size_oblivious_result: Optional[sympy.Basic] = None, + expr_sym_node_id: Optional[int] = None, + ) -> GuardOnDataDependentSymNode: + # TODO: in a Dynamo context, having user code, and having the + # name of the local, will be much better + size_like_symbols = [] + for s in expr.free_symbols: + stacktrace = "".join(self.var_to_stack[s].format()) + self.log.debug( + "Data dependent variable '%s' allocated at:\n%s", s, stacktrace + ) + if s in self.size_like: + size_like_symbols.append(s) + size_oblivious_result_msg = "" + if size_oblivious_result is not None: + size_oblivious_result_msg = ( + f"ATTENTION: guard_size_oblivious would fix the error, evaluating expression to {size_oblivious_result}.\n" + "Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.\n\n" + ) + sloc, maybe_extra_debug = self._get_stack_summary(True) + if expr.is_integer: # type: ignore[attr-defined] + desc = ( + "Could not extract specialized integer from data-dependent expression" + ) + else: + desc = "Could not guard on data-dependent expression" + msg = ( + f"{desc} {expr} (unhinted: {unhinted_expr}). " + f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n" + f"{size_oblivious_result_msg}" + f"Caused by: {sloc}\n" + 'For more information, run with TORCH_LOGS="dynamic"\n' + "For extended logs when we create symbols, also add " + f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{",".join(map(str, expr.free_symbols))}"\n' + "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n" + "For more debugging help, see " + "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n" + + maybe_extra_debug + # TODO: Help text about how to use our runtime tests to fix this + # problem + ) + + dtrace_structured( + "guard_on_data_dependent_error", + metadata_fn=lambda: { + "expr": repr(expr), + "unhinted_expr": repr(unhinted_expr), + "expr_id": self._expr_sym_node_id, + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + return GuardOnDataDependentSymNode(expr, msg) + + def _update_var_to_range( + self, + symbol: sympy.Symbol, + vr: ValueRanges, + vr_sloc: Optional[ValueRangesSLoc] = None, + *, + is_constraint: bool = False, + ) -> None: + lower, upper = vr.lower, vr.upper + + # If we have a size-like unbacked SymInt, refuse to refine the range to be + # less than two. This is because when we intersect this range + # with [2, inf] for size oblivious tests, the range would be + # unsatisfiable. In other words, once you have a size-like + # unbacked SymInt, we can never learn that it is exactly zero or one, + # because we would now give inconsistent results for all size + # oblivous tests! + if upper < 2 and symbol in self.size_like: + vr = ValueRanges(lower, 2) + + # Updates the range and the guards corresponding to each bound of the symbol. + if symbol not in self.var_to_range: + self.log.debug("_update_var_to_range %s = %s (new)", symbol, vr) + self.var_to_range[symbol] = vr + if vr_sloc is None: + sloc = self._get_sloc() + vr_sloc = ValueRangesSLoc(sloc, sloc) + self.var_to_range_sloc[symbol] = vr_sloc + else: + old = self.var_to_range[symbol] + new = old & vr + if new != old: + if vr_sloc is None: + sloc = self._get_sloc() + vr_sloc = ValueRangesSLoc(sloc, sloc) + if new.lower != old.lower: + self.var_to_range_sloc[symbol].lower = vr_sloc.lower + if new.upper != old.upper: + self.var_to_range_sloc[symbol].upper = vr_sloc.upper + self.var_to_range[symbol] = new + self.log.debug("_update_var_to_range %s = %s (update)", symbol, new) + + if (v := self.var_to_val.get(symbol)) is not None: + r = self.var_to_range[symbol] + if v not in r: + # For constraint failure, delay this for later + # TODO: Rework all of this, the constraint logic is very + # duplicative with regular reasoning + if not is_constraint: + assert v in r, f"{v} not in {r}" + + def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None: + """ + Adds or updates a replacement for a symbol. + Use this instead of `self.replacements[a] = tgt`. + """ + if tgt == self.replacements.get(a, None): + return + + if a in tgt.free_symbols: + return + + # Precondition: a == tgt + assert isinstance(a, sympy.Symbol) + + if ( + self.allow_complex_guards_as_runtime_asserts + and not _is_supported_equivalence(tgt) + ): + return # continuing leads to placeholder shapes having complex expressions that we can't resolve + + # Handles nested tensor symbolic variables which don't have + # var_to_range bounds + tgt_bound = None + if a in self.var_to_range: + src_bound = self.var_to_range[a] + + # First, refine the value range of a based on the computed value range + # of tgt. This is always OK to do, even if we decide not to do the + # substitution in the end. This might be a no-op, if a already has + # a tighter bound + tgt_bound = self.bound_sympy(tgt) + self._update_var_to_range(a, tgt_bound) + + # Next, check if we can update the range of free symbols in tgt + # based on the range in a. But only do it if: + # - the source bound non-trivially improves over what we get out of + # the existing bounds. + # - the replacement is univariate and we can invert the tgt expression + if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1: + b = next(iter(tgt.free_symbols)) + # Try to invert the equality + r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) + if r is not None: + self.log.debug( + "set_replacement: solve for %s in %s == %s gives %s", + b, + a, + tgt, + r, + ) + # The solution here can be non-integral, for example, if + # we have s0 = 2*s1, then s1 = s0/2. What we would like + # to do is calculated the bounds in arbitrary precision, + # and then requantize the bound to integers when we are + # done. + rat_b_bound = self.bound_sympy(r[1]) + b_bound = ValueRanges( + CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper) + ) + self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a]) + tgt_bound = self.bound_sympy(tgt) + assert tgt_bound.issubset(src_bound), ( + f"{tgt_bound=} not a subset of {src_bound=}" + ) + + # TODO: Should we propagate size-like-ness? + # + # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1 + # to become size-like. + # + # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T + # propagate in this case, because what if u0 == 0, then u1 is negative + # and clearly isn't a size. So, at minimum, any f(x) whose value + # range isn't [0, inf] given x in [0, inf] cannot propagate + # size-like-ness. But there are many situations where you could + # imagine u1 is going to be size-like and actually you just didn't + # have a refined enough value range on u0. Since even innocuous + # looking arithmetic operations can destroy size-like-ness, it's + # best to not propagate it at all and force the user to annotate it + # as necessary. + # + # Compromise: we preserve size-like-ness only for exact equality + # and nothing else. + if a in self.size_like and isinstance(tgt, sympy.Symbol): + self.size_like.add(tgt) + elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like: + self.size_like.add(a) + + # Now, decide if we will do the substitution. + # + # - If the source has a non-trivial range, only substitute if + # we preserve this range. Note that we may have propagated + # the src_range to free variables in tgt when tgt is univariate + # and we could find an inverse, which helps us achieve this. + # This ensures we never "forget" about user defined ranges, + # even if they end up being defined on composite formulas + # like s0 + s1. + # + # - If the variable is unbacked, only substitute if the substitution + # would preserve the bounds also under size-like-ness conditions. + + if not tgt_bound.issubset(src_bound): + self.log.debug( + "skipped set_replacement %s = %s (%s) [%s not subset of %s]", + a, + tgt, + msg, + tgt_bound, + src_bound, + ) + return + elif a in self.size_like: + tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True) + src_bound_so = self.bound_sympy(a, size_oblivious=True) + if not tgt_bound_so.issubset(src_bound_so): + self.log.debug( + "skipped set_replacement %s = %s (%s) " + "[%s not subset of %s (size-oblivious conditions)]", + a, + tgt, + msg, + tgt_bound_so, + src_bound_so, + ) + return + + if isinstance(tgt, (sympy.Integer, sympy.Float)): + # specializing to a constant, which is likely unexpected (unless + # you specified dynamic=True) + + user_tb = TracingContext.extract_stack() + trace_structured( + "symbolic_shape_specialization", + metadata_fn=lambda: { + "symbol": repr(a), + "sources": [s.name() for s in self.var_to_sources.get(a, [])], + "value": repr(tgt), + "reason": msg, + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + "user_stack": ( + structured.from_traceback(user_tb) if user_tb else None + ), + }, + ) + + for source in self.var_to_sources.get(a, []): + if user_tb: + self.user_specialization_stacks[source] = user_tb + self.framework_specialization_stacks[source] = ( + CapturedTraceback.extract(cpp=True) + ) + + if config.print_specializations: + self.log.warning( + "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt + ) + self.log.debug("SPECIALIZATION", stack_info=True) + log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound) + self.replacements[a] = tgt + # NB: the replacement may get refined, but the user will find the + # FIRST one most useful (TODO: Maybe we could consider tracking all of + # them) + if a not in self.replacements_slocs: + self.replacements_slocs[a] = self._get_sloc() + self._update_version_counter() + + # When specializing 'a == tgt', the equality should be also conveyed to + # Z3, in case an expression uses 'a'. + self._add_target_expr(sympy.Eq(a, tgt, evaluate=False)) + + def _add_divisible(self, expr: sympy.Expr) -> None: + self.divisible.add(expr) + self._update_version_counter() + + @_lru_cache + @record_shapeenv_event() + def _find(self, a: sympy.Symbol) -> sympy.Expr: + """ + Implements a DSU-like algorithm to find the variable that represents a + Also handles transitive non-identity replacements. + + a: b + c + c: d + """ + if a not in self.replacements: + return a + res = self.replacements[a] + cur_replace = {s: self._find(s) for s in res.free_symbols} + replaced, changed = self.replacements[a]._xreplace(cur_replace) + if changed: + self._set_replacement(a, replaced, "find") + return self.replacements[a] + + @lru_cache(256) + def _maybe_guard_rel(self, expr: sympy.Expr) -> None: + """ + The relational guard is guarded to be true. Use this information to + simplify shapes (i.e. a == b or a % 5 == 0) + """ + if isinstance(expr, sympy.And): + for arg in expr.args: + self._maybe_guard_rel(arg) + return + elif not isinstance(expr, sympy.Rel): + log.warning( + "_maybe_guard_rel() was called on non-relation expression %s", expr + ) + return + + # A good example of what goes wrong if you don't do this is + # python test/functorch/test_aotdispatch.py -k + # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32 + if isinstance(expr, sympy.Ne): + return + + free = list(expr.free_symbols) + + assert len(free) > 0, ( + f"The expression should not be static by this point: {expr}" + ) + # In case of really gnarly expression, we don't blow up + if len(free) > 5: + return + + # Prioritize unbacked symints for solving by ordering them last. + # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3). + # (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols) + # Prefer to simplify out symbols with ephemeral sources. + def _smart_symbol_sort(x: sympy.Symbol) -> tuple[int, int, str]: + has_only_ephemeral_sources = x in self.var_to_sources and all( + s.is_ephemeral() for s in self.var_to_sources[x] + ) + # NB: size_hint is int, not sympy.Expr, do not use int_oo here + hint_size = self.size_hint(x, allow_none=True) + if hint_size is None: + size = sys.maxsize + elif symbol_is_type(x, SymT.SIZE): + assert isinstance(hint_size, sympy.Expr) + size = int(hint_size) + else: + size = sys.maxsize + name = x.name + # 1 puts ephemeral sourced symbols first when sorting in reverse + return (1 if has_only_ephemeral_sources else 0, size, name) + + free = sorted(free, key=_smart_symbol_sort, reverse=True) # type: ignore[attr-defined] + lhs = expr.lhs + rhs = expr.rhs + + self._refine_ranges(expr) + + # The rest of this stuff is for equality only + if not isinstance(expr, sympy.Eq): + return + + if not expr.has(Mod): + try: + floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv)) + if len(floor_div_atoms) > 0 and any( + a.divisor != 1 for a in floor_div_atoms + ): + raise NotImplementedError + + # Never replace unbacked symbols with other unbacked symbols. + # This is error prone because you can cause references to + # unbacked symbols to time travel backwards. E.g., + # + # u1 = x.item() + # ... use of u1 ... + # u2 = y.item() + # u3 = z.item() + # torch._check(u1 == u2 + u3) + # + # If you replace u1 with u2 + u3, then the use of u1 now + # references u2 and u3 prior to them actually being bound at + # runtime. It's pretty inconvenient to setup control + # dependencies for substitutions, so ban it entirely. + def trivial_solve(lhs: sympy.Expr, rhs: sympy.Expr) -> bool: + if isinstance(lhs, sympy.Symbol): + if free_unbacked_symbols(lhs) and not free_unbacked_symbols( + rhs + ): + return True + if symbol_is_type(lhs, SymT.FLOAT): + return True + # TODO: Maybe trivial solutions for int should also be + # done? + return False + + # short-circuit when no solving is needed + if trivial_solve(lhs, rhs): + self._set_replacement(lhs, self._find(rhs), "trivial_lhs") + elif trivial_solve(rhs, lhs): + self._set_replacement(rhs, self._find(lhs), "trivial_rhs") + else: + r = try_solve(expr, free[0], floordiv_inequality=False) + if r is not None and all( + t.is_integer for t in sympy.preorder_traversal(r[1]) + ): + new_var = self._find(r[1]) + ok = len(free_unbacked_symbols(new_var)) == 0 + if ok: + self._set_replacement(free[0], new_var, "solve") + except NotImplementedError: + pass + if expr.has(Mod): + mod_expr = next(iter(expr.atoms(Mod))) + try: + r = try_solve(expr, mod_expr, floordiv_inequality=False) + if r is not None and r[1] == 0: + self._add_divisible(mod_expr) + # This is a little bit of extra logic to make things like + # torch.empty(i0, q).view(c, -1, q) work out + p, q = mod_expr.args + if ( + isinstance(q, sympy.Number) + and isinstance(p, sympy.Mul) + and len(p.args) == 2 + ): + c, i0 = p.args + # Given Mod(c * i0, q) == 0 + if ( + isinstance(c, sympy.Number) + and isinstance(i0, sympy.Symbol) + and self.is_unbacked_symint(i0) + ): + # We have Mod(i0, q / c) == 0, which means we can + # rewrite i0 as (q / gcd(q, c)) * i1 + d = q / sympy.gcd(q, c) # TODO: CleanDiv? + i1 = self.create_unbacked_symint().node.expr + # Propagate the value ranges. It doesn't really + # matter if we use truediv or floordiv, because we + # have established divisibility. + self._update_var_to_range( + i1, + SymPyValueRangeAnalysis.floordiv( + self.var_to_range[i0], ValueRanges.wrap(d) + ), + ) + # Propagate hints (real tensor tracing) + if i0 in self.unbacked_var_to_val: + self.set_unbacked_var_to_val( + i1, self.unbacked_var_to_val[i0] // d + ) + # Propagate size-like-ness + if i0 in self.size_like: + self.size_like.add(i1) + self._set_replacement(i0, d * i1, "divisibility") + + except NotImplementedError: + pass + return + + # See: Note - On 0/1 specialization + def _default_value_range( + self, do_not_specialize_zero_one: bool = False + ) -> ValueRanges: + lower = 0 if (do_not_specialize_zero_one or not self.specialize_zero_one) else 2 + return ValueRanges(lower, int_oo) + + def _default_unspecified_value_range(self) -> ValueRanges: + return ValueRanges.unknown_int() + + @_lru_cache + def _simplify_floor_div(self, expr: sympy.Expr) -> sympy.Expr: + floor_divs = tuple(expr.atoms(FloorDiv)) + # we expect floor_divs to be exact, + # and thus add the guards for the exact floordivs, + # even if tracing doesn't require them otherwise + for fd in reversed(floor_divs): + base, divisor = fd.args + mod_expr = Mod(base, divisor) + eq_expr = sympy.Eq(mod_expr, 0) + # add necessary mod guards + self.evaluate_expr(eq_expr) + return self.simplify(expr) + + # We're about to add a guard/runtime assert, check if the ShapeEnv is frozen + # and if so issue a warning + def _check_frozen(self, expr: sympy.Basic, concrete_val: sympy.Basic) -> None: + if self.frozen: + self.counter["ignored_backward_guard"] += 1 + signpost_event( + "dynamic", + "evaluate_expr_frozen", + { + **self.co_fields, + "ignored_guard": f"{expr} == {concrete_val}", + # no version = original state (this signpost is expected) + # version 2 = dynamic backwards is eagerly compiled + "version": 2, + }, + ) + log.info( + "Ignored guard %s == %s, this could result in accuracy problems", + expr, + concrete_val, + # only print stack trace when debug mode is on (e.g. TORCH_LOGS="dynamic") + stack_info=True if log.getEffectiveLevel() < logging.WARNING else False, + ) + + def _get_user_frame(self) -> Optional[types.FrameType]: + frame = inspect.currentframe() + while frame is not None: + if frame.f_code.co_filename not in uninteresting_files(): + return frame + frame = frame.f_back + return frame + + def _get_stack_summary( + self, is_debug: bool = False, framework_loc: Optional[str] = None + ) -> tuple[SLoc, str]: + floc: Optional[Union[str, traceback.FrameSummary]] = framework_loc + if floc is None: + frame = self._get_user_frame() + try: + if frame is not None: + floc = traceback.FrameSummary( + frame.f_code.co_filename, + frame.f_lineno, + frame.f_code.co_name, + ) + finally: + del frame + + # NB: this stack is truncated, but it's fine because the main + # stack_info will give you the rest of the info you need + maybe_user_loc = None + user_tb = TracingContext.extract_stack() + if user_tb: + idx = len(user_tb) - 1 + while idx > 0 and user_tb[idx].filename in uninteresting_files(): + idx -= 1 + maybe_user_loc = format_frame(user_tb[idx], line=True) + + maybe_extra_debug = "" + if is_debug and user_tb: + maybe_extra_debug = ( + "\nUser Stack (most recent call last):\n" + + " (snipped, see stack below for prefix)\n" + + "".join(traceback.format_list(user_tb)) + ) + if is_debug and config.extended_debug_cpp: + cpp_stack = CapturedTraceback.extract(cpp=True) + maybe_extra_debug += "\nC++ stack trace:\n" + "".join(cpp_stack.format()) + elif is_debug: + maybe_extra_debug += ( + "\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1" + ) + + return SLoc(floc, maybe_user_loc), maybe_extra_debug + + # Pass in framework_loc to override the framework location info + def _get_sloc(self, framework_loc: Optional[str] = None) -> SLoc: + sloc, _ = self._get_stack_summary(framework_loc=framework_loc) + return sloc + + def _generate_unique_id(self, source_name: str) -> int: + attempt = int(hashlib.sha256(source_name.encode()).hexdigest(), 16) % 100 + while attempt in self.unique_ids: + attempt += 1 + self.unique_ids.add(attempt) + return attempt + + def _find_frame_locals(self) -> _FrameLocalResult: + """ + Given the current user code frame, finds the relevant lines of code, + values of symbolic locals, and free symbols involved. + """ + frame_locals: dict[str, Any] = {} + frame_symbols: dict[str, str] = {} + + if ( + frame := _find_user_code_frame() + ) is None or frame.f_code.co_filename == "": + return _FrameLocalResult() + + # find bytecode instructions relevant to the frame + instructions = list(dis.Bytecode(frame.f_code)) + co_lines, offset = inspect.getsourcelines(frame.f_code) + start, end, cur = None, None, None + for i, instr in enumerate(instructions): + if instr.starts_line is not None: + cur = instr.starts_line + if cur != frame.f_lineno: + continue + if start is None: + start = end = i + else: + end = i + + if start is None or end is None: # no instructions found + return _FrameLocalResult() + + # track involved locals and free symbols + def go(x: Any) -> Optional[str]: + if isinstance(x, torch.Tensor): + for y in x.size(): + go(y) + for y in x.stride(): + go(y) + go(x.storage_offset()) + return ( + f"Tensor(shape: {x.size()}, " + f"stride: {x.stride()}, " + f"storage_offset: {x.storage_offset()})" + ) + elif isinstance(x, (SymBool, SymInt, SymFloat)): + for s in x.node.expr.free_symbols: + if str(s) in frame_symbols: # type: ignore[operator] + continue + if s in self.var_to_sources: + frame_symbols[str(s)] = self.var_to_sources[s][0].name() # type: ignore[assignment] + return str(x) + return None + + # go through instructions, seeing linenos & involved locals + last_lineno = frame.f_lineno + for instr in instructions[start : end + 1]: + if (lineno := instr.starts_line) is not None: + last_lineno = max(last_lineno, lineno) + if isinstance(instr.argval, str) and instr.argval in frame.f_locals: + flat_locals = pytree.tree_flatten(frame.f_locals[instr.argval])[0] + frame_locals[instr.argval] = [ + go(flat_local) for flat_local in flat_locals + ] + + # store LOC + locs = co_lines[frame.f_lineno - offset : last_lineno + 1 - offset] + if not locs: + return _FrameLocalResult() + + indent = len(locs[0]) - len(locs[0].lstrip()) + frame_loc = "".join([loc[indent:] for loc in locs]).strip() # type: ignore[assignment] + return _FrameLocalResult( + loc=frame_loc, locals=frame_locals, symbols=frame_symbols + ) + + def _log_guard(self, prefix: str, g: SympyBoolean, forcing_spec: bool) -> None: + dtrace_structured( + "guard_added", + metadata_fn=lambda: { + "expr": str(g), + "prefix": prefix, + "expr_node_id": self._expr_sym_node_id, + "user_stack": structured.get_user_stack(3), + "stack": structured.get_framework_stack(3), + "symbol_to_sources": { + str(v): k + for k, v in self.source_to_var.items() + if v in g.free_symbols + }, + "frame_locals": asdict(self._find_frame_locals()), + }, + ) + trace_structured( + "guard_added_fast", + metadata_fn=lambda: { + "expr": str(g), + "user_stack": structured.from_traceback(TracingContext.extract_stack()), + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + if self.log.isEnabledFor(logging.INFO): + str_g = str(g) + is_debug = ( + config.extended_debug_guard_added is not None + and str_g == config.extended_debug_guard_added + ) + sloc, maybe_extra_debug = self._get_stack_summary(is_debug) + maybe_more_info = "" + if not is_debug: + maybe_more_info = ( + ", for more info run with " + f'TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="{str_g}"' + ) + self.log.info( + "%s %s [guard added] %s%s%s", + prefix if not forcing_spec else f"{prefix} (forcing_spec)", + str_g, + sloc, + maybe_more_info, + maybe_extra_debug, + stack_info=is_debug, + ) + + # A local variable to evaluate_expr stored in the class to avoid + # using it for the lru_cache that is on top of it since it does + # not effect the results. When needed its read directly. + _expr_sym_node_id: Optional[int] = None + + def evaluate_sym_node( + self, + sym_node: SymNode, + size_oblivious: bool = False, + fallback_value: Optional[bool] = None, + ) -> sympy.Basic: + """ + Given a a SymNode, evaluates sym_node.expr, adding guards if necessary. + """ + + self._expr_sym_node_id = id(sym_node) + return self.evaluate_expr( + sym_node.expr, + sym_node.hint, + sym_node.fx_node, + size_oblivious, + fallback_value=fallback_value, + ) + + def _is_python_assert(self) -> bool: + # Check if this boolean is used in an assertion, bytecode pattern for + # assertions is pretty stable for Python 3.7--3.13, ported with minimal + # changes from torch/fx/proxy.py + # Bytecode pattern for `assert` statements: + # TO_BOOL / COMPARE_OP # Only for Python >= 3.13 + # POP_JUMP_IF_TRUE + # LOAD_ASSERTION_ERROR + # RAISE_VARARGS + frame = self._get_user_frame() + assert frame is not None + + insts = list(dis.get_instructions(frame.f_code)) + if sys.version_info >= (3, 11): + # For Python >= 3.11, instructions can be 2-4 bytes long. + from bisect import bisect_left + + cur = bisect_left(insts, frame.f_lasti, key=lambda x: x.offset) + else: + # For Python <= 3.10, instructions are always 2 bytes. + cur = frame.f_lasti // 2 + + if sys.version_info >= (3, 13): + if insts[cur].opname in ("TO_BOOL", "COMPARE_OP"): + # Peek 1 instruction further. + cur += 1 + inst = insts[cur] + + if inst.opname == "POP_JUMP_IF_TRUE" and inst.arg is not None: + first = insts[cur + 1] + + starts_with_assert = ( + first.opname == "LOAD_GLOBAL" + and first.argval == "AssertionError" + or first.opname == "LOAD_ASSERTION_ERROR" + ) + if starts_with_assert and insts[cur + 2].opname == "RAISE_VARARGS": + return True + return False + + def _log_real_tensor_propagation( + self, orig_expr: sympy.Basic, unsound_result: sympy.Basic + ) -> None: + log.warning( + "propagate_real_tensors evaluate_expr(%s) -> %s", + orig_expr, + unsound_result, + ) + trace_structured( + "propagate_real_tensors", + metadata_fn=lambda: { + "expr": repr(orig_expr), + "result": repr(unsound_result), + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + dtrace_structured( + "propagate_real_tensors_provenance", + metadata_fn=lambda: { + "expr": repr(orig_expr), + "result": repr(unsound_result), + "expr_node_id": self._expr_sym_node_id, + "user_stack": structured.get_user_stack(3), + "stack": structured.get_framework_stack(3), + "symbol_to_sources": { + str(v): k + for k, v in self.source_to_var.items() + if v in orig_expr.free_symbols + }, + "frame_locals": asdict(self._find_frame_locals()), + }, + ) + + def evaluate_expr( + self, + orig_expr: sympy.Basic, + hint: Optional[Union[int, bool, float]] = None, + fx_node: Optional[torch.fx.Node] = None, + size_oblivious: bool = False, + fallback_value: Optional[bool] = None, + *, + forcing_spec: bool = False, + ) -> sympy.Basic: + """ + Given an expression, evaluates it, adding guards if necessary + When fallback_value is not None the function return fallback_value instead of failing with data dependent error. + """ + + # Add extra state that evaluate_expr() depends on. + suppress_guards_tls = ShapeEnv._suppress_guards_tls() + return self._inner_evaluate_expr( + orig_expr, + hint, + fx_node, + size_oblivious, + forcing_spec, + suppress_guards_tls, + fallback_value, + ) + + @lru_cache(256) + @record_shapeenv_event(save_tracked_fakes=True, name="evaluate_expr") + def _inner_evaluate_expr( + self, + orig_expr: sympy.Basic, + hint: Optional[Union[int, bool, float]], + fx_node: Optional[torch.fx.Node], + size_oblivious: bool, + forcing_spec: bool, + _suppress_guards_tls: bool, + fallback_value: Optional[bool] = None, + ) -> sympy.Basic: + try: + return self._evaluate_expr( + orig_expr, + hint, + fx_node, + size_oblivious, + fallback_value, + forcing_spec=forcing_spec, + ) + except Exception as e: + if isinstance(e, GuardOnDataDependentSymNode): + pass + else: + self.log.warning( + "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s", + orig_expr, + hint, + size_oblivious, + forcing_spec, + ) + raise + + def _log_suppressed_dde(self, a: SymBool, assumed_value: bool) -> None: + sloc, extra = self._get_stack_summary(True) + log.info( + "could not evaluate %s due to data dependency, it was assumed to be %s with no runtime assertions %s %s", + a, + assumed_value, + sloc, + extra, + ) + + def _evaluate_expr( + self, + orig_expr: sympy.Basic, + hint: Optional[Union[bool, int, float]] = None, + fx_node: Optional[torch.fx.Node] = None, + size_oblivious: bool = False, + fallback_value: Optional[bool] = None, + *, + forcing_spec: bool = False, + ) -> sympy.Basic: + # TODO: split conjunctions and evaluate them separately + + if isinstance( + orig_expr, + (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse), + ): + return orig_expr + + # Don't track this one. (Because this cache is inside this function the + # cache only lasts for the invocation of this function call) + @functools.cache + def compute_concrete_val() -> sympy.Basic: + if hint is None: + # This is only ever called for expressions WITHOUT unbacked + # symbols + r = self.size_hint(orig_expr) + assert r is not None + return r + else: + return sympy.sympify(hint) + + concrete_val: Optional[sympy.Basic] + + # Check if: + # 1. 'translation_validation' is set + # 2. the corresponding 'fx_node' is not 'None' + # 3. the guard should not be suppressed + # 4. the guard doesn't contain backed symfloat symbols + # since z3 can't handle floats + # 5. fallback_value is none. + # If all of the above check, we create an FX node representing the + # actual expression to be guarded. + node = None + fresh = False + if ( + self._translation_validation_enabled + and fx_node is not None + and not self._suppress_guards_tls() + and not size_oblivious + and not any(symbol_is_type(s, SymT.FLOAT) for s in orig_expr.free_symbols) + and fallback_value is None + ): + # TODO: does this even worked with unbacked :think: + concrete_val = compute_concrete_val() + if concrete_val is sympy.true: + node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) + elif concrete_val is sympy.false: + neg, _ = self._create_fx_call_function(operator.not_, (fx_node,)) + node, fresh = self._create_fx_call_function(torch._assert, (neg,)) + else: + eql, _ = self._create_fx_call_function( + operator.eq, (fx_node, concrete_val) + ) + node, fresh = self._create_fx_call_function(torch._assert, (eql,)) + + assert node is not None + # If this is a fresh node, we have to remember the event index that + # corresponds to this assertion node. + # Reason: so that, given an assertion node, we can replay the ShapeEnv + # events until the point where this assertion node was freshly created. + if fresh: + self._add_fx_node_metadata(node) + + # After creating the FX node corresponding to orig_expr, we must make sure that + # no error will be raised until the end of this function. + # + # Reason: the translation validation may become invalid otherwise. + # + # If an error is raised before the end of this function, we remove the FX node + # inserted, and re-raise the error. + guard = None + + try: + if orig_expr.is_number: + self.log.debug("eval %s [trivial]", orig_expr) + if hint is not None: + if isinstance(hint, bool): + assert orig_expr == hint, f"{orig_expr} != {hint}" + else: + assert sympy.Eq(orig_expr, hint), f"{orig_expr} != {hint}" + return orig_expr + + expr = orig_expr + + static_expr = self._maybe_evaluate_static( + expr, size_oblivious=size_oblivious + ) + if static_expr is not None: + self.log.debug( + "eval %s == %s [statically known]", + ( + f"size_oblivious({orig_expr})" + if size_oblivious + else size_oblivious + ), + static_expr, + ) + if ( + not size_oblivious + and config.backed_size_oblivious + and hint is not None + ): + # TODO: maybe reconcile this with use of counterfactual hints + # in unbacked case + assert static_expr == hint, f"{static_expr} != {hint}" + return static_expr + + transmute_into_runtime_assert = False + + concrete_val = None + if not (expr.free_symbols <= self.var_to_val.keys()): + # TODO: dedupe this with _maybe_evaluate_static + # Attempt to eliminate the unbacked SymInt + new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) + assert new_expr is not None + if not (new_expr.free_symbols <= self.var_to_val.keys()): + ok = False + + # fallback_value is set when guard_or_true or guard_or_false are used. + if not ok and fallback_value is not None: + self._log_suppressed_dde(orig_expr, fallback_value) + return fallback_value + + # oblivious_var_to_val will be defined iff we have sizes with DimDynamic.OBLIVIOUS_SIZE type. + # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113 + if ( + self.oblivious_var_to_val + and not ( + correct_hint := orig_expr.xreplace( + self.oblivious_var_to_val + ) + ).free_symbols + and not ( + counterfactual_hint := orig_expr.xreplace( + { + k: max(2, v) + for k, v in self.oblivious_var_to_val.items() + } + ) + ).free_symbols + and correct_hint == counterfactual_hint + ): + # TODO: better logging + log.info( + "oblivious_size %s -> %s (passed counterfactual)", + orig_expr, + correct_hint, + ) + concrete_val = correct_hint + # NB: do NOT transmute into runtime assert + ok = True + + # unbacked_var_to_val is not None iff propagate_real_tensors is on. + # if propagate_real_tensors is on, we check the example values to generate (unsound_result) + # and if they pass we add a runtime assertions and continue. + if ( + not ok + and self.unbacked_var_to_val + and not ( + unsound_result := orig_expr.xreplace( + self.unbacked_var_to_val + ).xreplace(self.var_to_val) + ).free_symbols + ): + self._log_real_tensor_propagation(orig_expr, unsound_result) + transmute_into_runtime_assert = True + concrete_val = unsound_result + ok = True + + # Check if this is coming from a python assert statement, if so, convert it to a runtime assertion + # instead of failing. + if not ok and self.trace_asserts and self._is_python_assert(): + concrete_val = sympy.true + transmute_into_runtime_assert = True + ok = True + + if not ok: + size_oblivious_result = None + # compute size_oblivious_result to suggest it as a fix for the user if it works. + if not size_oblivious: + size_oblivious_result = self._maybe_evaluate_static( + expr, size_oblivious=True + ) + raise self._make_data_dependent_error( + expr.xreplace(self.var_to_val), + expr, + size_oblivious_result=size_oblivious_result, + expr_sym_node_id=self._expr_sym_node_id, + ) + else: + expr = new_expr + + if concrete_val is None: + concrete_val = compute_concrete_val() + self._check_frozen(expr, concrete_val) + + if ( + config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY + and isinstance(hint, bool) + and isinstance(expr, (sympy.Eq, sympy.Ne)) + ): + expr = sympy.Not(expr) + + # Turn this into a boolean expression, no longer need to consult + # concrete_val + if concrete_val is sympy.true: + g = cast(SympyBoolean, expr) + elif concrete_val is sympy.false: + g = sympy.Not(expr) + else: + g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type] + + if transmute_into_runtime_assert: + self.guard_or_defer_runtime_assert( + g, f"propagate_real_tensors: {orig_expr} == {concrete_val}" + ) + return concrete_val + + if not self._suppress_guards_tls(): + self._log_guard("eval", g, forcing_spec=forcing_spec) + + # TODO: If we successfully eliminate a symbol via equality, it + # is not actually necessary to save a guard for the equality, + # as we will implicitly generate a guard when we match that + # input against the symbol. Probably the easiest way to + # implement this is to have maybe_guard_rel return a bool + # saying if it "subsumed" the guard (and therefore the guard + # is no longer necessary) + self._maybe_guard_rel(g) + + if not self.allow_complex_guards_as_runtime_asserts: + # at this point, we've evaluated the concrete expr value, and have + # flipped/negated the guard if necessary. Now we know what to guard + # or defer to runtime assert on. + guard = ShapeGuard( + g, self._get_sloc(), size_oblivious=size_oblivious + ) + self.guards.append(guard) + self.axioms.update(dict(self.get_implications(self.simplify(g)))) + else: + # it's fine to defer simple guards here without checking, + # the _maybe_guard_rel() call above will set replacements if possible, + # and so the result here will be statically known + self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}") + else: + self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec) + + except Exception: + if fresh: + self._remove_fx_node(node) + raise + + if not self._suppress_guards_tls(): + if guard is not None: # we might have deferred this to runtime assert + for s in g.free_symbols: + self.symbol_guard_counter[s] += 1 + # Forcing_spec to avoid infinite recursion + if ( + not forcing_spec + and config.symbol_guard_limit_before_specialize is not None + and self.symbol_guard_counter[s] + > config.symbol_guard_limit_before_specialize + ): + # Force specialization + self.log.info( + "symbol_guard_limit_before_specialize=%s exceeded on %s", + config.symbol_guard_limit_before_specialize, + s, + ) + self.evaluate_expr(s, forcing_spec=True) + + return concrete_val + + def cleanup(self) -> None: + """ + Break reference cycles. + + This destroys the stacks. If you really want to keep them, we + just need some way to break references on code objects. + """ + for s in self.var_to_stack.values(): + s.cleanup() + for ras in self.deferred_runtime_asserts.values(): + for ra in ras: + ra.stack.cleanup() + + @lru_cache(256) + @record_shapeenv_event(save_tracked_fakes=True) + def guard_or_defer_runtime_assert( + self, orig_expr: SympyBoolean, msg: str, fx_node: Optional[torch.fx.Node] = None + ) -> bool: + """ + Adds a guard that orig_expr is True if we can or fall back to adding an assert + that is checked at runtime. + + Args: + orig_expr (sympy.Expr): Boolean expression to assert is true + msg (str): Message to display on assertion failure + fx_node (Optional, torch.fx.Node): node in ``self.graph`` corresponding + to the expression, if applicable + """ + expr = orig_expr + + # TODO: split conjunctions and evaluate them separately + + static_expr = self._maybe_evaluate_static(expr) + if static_expr is not None: + self.log.debug( + "runtime_assert %s == %s [statically known]", orig_expr, static_expr + ) + # TODO: assert bool(static_expr) + return bool(static_expr) + + # Attempt to eliminate the unbacked SymInt + new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) + assert new_expr is not None + if ( + not self.prefer_deferred_runtime_asserts_over_guards + and new_expr.free_symbols <= self.var_to_val.keys() + ): + # Do a normal guard + return self.evaluate_expr(new_expr, fx_node=fx_node) + # NB: Don't use new_expr as expr; it could contain gunk like shape0 + # which we don't want to guard on + + if ( + self._translation_validation_enabled + and fx_node is not None + and not self._suppress_guards_tls() + ): + node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) + assert node is not None + if fresh: + self._add_fx_node_metadata(node) + + if not self._suppress_guards_tls(): + self._log_guard("runtime_assert", orig_expr, forcing_spec=False) + # If you're here because of this assert, read Note [Backwards runtime asserts] + # in torch/_inductor/graph.py + if self.runtime_asserts_frozen: + log.debug("runtime_asserts_frozen but then got %s", expr) + self._check_frozen(expr, sympy.true) + # eliminate symbols on equality tests / refine ranges + self._maybe_guard_rel(expr) + + # canonicalise to remove equations that are trivially equal + orig_expr = expr + expr = canonicalize_bool_expr(expr) + stack = CapturedTraceback.extract(skip=1) + ra = RuntimeAssert(expr, msg, stack) + # TODO: Do this in a way that is less janky than int(s.name[1:]) + cands = sorted( + (s for s in expr.free_symbols if symbol_is_type(s, SymT.UNBACKED_INT)), + key=lambda s: int(s.name[1:]), + ) + # Is None when prefer_deferred_runtime_asserts_over_guards=True + # and the guard in question has no unbacked SymInts in front + ix = cands[-1] if cands else None + self.deferred_runtime_asserts.setdefault(ix, []).append(ra) + self.axioms.update(dict(self.get_implications(self.simplify(expr)))) + self.num_deferred_runtime_asserts += 1 + self._update_version_counter() + else: + self._log_guard( + "runtime_assert [guard suppressed]", orig_expr, forcing_spec=False + ) + + return True + + # Refines the ranges of the variables present in 'guard'. + # + # This function tries to refine the range of the variables inside + # 'guard' by reasoning about it. Specifically, when 'guard' is a + # 'sympy.Relational' operation. + # + # It does mainly 3 things: + # 1. Tries to isolate a variable in the left-hand side + # 2. Compute the value range of the right-hand side + # 3. Update the value range of the variable, if better + def _refine_ranges(self, expr: SympyBoolean) -> None: + expr = self.simplify(expr) + + for symbol in expr.free_symbols: + assert isinstance(symbol, sympy.Symbol) + + if isinstance(self.var_to_val.get(symbol, None), SingletonInt): + # Skip var_to_range logic for SingletonInt which is only used + # for jagged layout NestedTensors today + continue + + r = try_solve(expr, symbol) + + if r is None or not (symbol.is_integer and r[1].is_integer): + # Range refinement only supports integer symbols for now. + # There are lots of SymPy bugs when it comes to comparing + # reals and integers, so we skip that for now. + continue + + r_expr, rhs = r + vr = self.var_to_range[symbol] + lower, upper = vr.lower, vr.upper + + rhs_vr = bound_sympy(rhs, self.var_to_range) + + # Let's suppose that we have a preexisting range for x [0, 100]. + # Now, we issue a guard x > y, where the range for y is [50, 150]. + # Then, lower = 0, rhs_vr.lower = 50 and therefore refinement can happen, + # refining x to [51, 100], since x must be greater than y, but the lowest + # y could be is 50. + # + # sympy.Eq may update both lower and upper bounds. + # sympy.G{t,e} may update the lower bound, only. + # sympy.L{t,e} may update the upper bound, only. + if lower < rhs_vr.lower and isinstance( + r_expr, (sympy.Eq, sympy.Ge, sympy.Gt) + ): + # Strictly greater relations allow us to refine a bit more, since + # x < y implies that the lower bound for x is: y + 1. + lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt)) + if upper > rhs_vr.upper and isinstance( + r_expr, (sympy.Eq, sympy.Le, sympy.Lt) + ): + upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt)) + + # Do nothing if the new value range is no better than what we already have. + if vr == ValueRanges(lower, upper): + continue + + # Updates the range and the guards corresponding to each bound of the symbol. + self._update_var_to_range(symbol, ValueRanges(lower, upper)) + # If the range is refined to singleton, set replacement + if self.var_to_range[symbol].is_singleton(): + self._set_replacement( + symbol, + self.var_to_range[symbol].lower, + "range_refined_to_singleton", + ) + + # Clears the cache, since this update can change the result. + self._maybe_evaluate_static.cache_clear() + + @lru_cache(maxsize=None) + @record_shapeenv_event() + def constrain_symbol_range( + self, s: sympy.Symbol, compiler_min: int, compiler_max: int + ) -> None: + upd_vr = ValueRanges(compiler_min, compiler_max) + old_vr = self.var_to_range.get(s, ValueRanges.unknown()) + self._update_var_to_range(s, upd_vr) + if (new_vr := self.var_to_range[s]) != old_vr: + log.info( + "constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper + ) + + +def _is_int(expr: object) -> bool: + return isinstance(expr, SymInt) and expr.node.expr.is_number + + +# WARNING: This is legacy, DO NOT USE +def _is_dim_dynamic(t: torch.Tensor, d: int) -> bool: + return hasattr(t, "_dynamo_dynamic_indices") and d in t._dynamo_dynamic_indices + + +class PropagateUnbackedSymInts(torch.fx.Interpreter): + def run_node(self, n: torch.fx.Node) -> Result: + """ + Run an FX node, propagating unbacked Symbol bindings to the new fake tensor + """ + from torch._guards import detect_fake_mode + + result = super().run_node(n) + rebind_unbacked(detect_fake_mode().shape_env, n, result) + return result + + +def _find_user_code_frame() -> Optional[types.FrameType]: + frame = inspect.currentframe() + while frame is not None: + if not frame.f_code.co_filename.startswith( + os.path.dirname(inspect.getfile(torch)) + os.path.sep + ): + break + frame = frame.f_back + return frame + + +def _blame_user_code(e: Exception, frame: types.FrameType) -> None: + frame_summary = traceback.FrameSummary( + frame.f_code.co_filename, + frame.f_lineno, + frame.f_code.co_name, + ) + msg = e.args[0] + msg += "\n\nThe following call raised this error:\n" + "".join( + traceback.StackSummary.from_list([frame_summary]).format() + ) + e.args = (msg,) + + +class _PythonMsgPrinter(PythonPrinter): + """ + Util printer that replaces sympy symbols with their source-level names + and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline + (i.e., as ==, !=, >, <). + """ + + def __init__(self, src_map: dict[str, list[str]]) -> None: + super().__init__() + self.src_map = src_map + + def _print_Symbol(self, sym: sympy.Symbol) -> str: + return self.src_map[sym.name][0] + + +def _is_non_negative_check(cond: sympy.Basic) -> Optional[str]: + """ + Check if a condition (SymPy expression) is checking for non-negative values (>= 0). + Returns the variable name if it's a non-negative check (>= 0), None otherwise. + """ + if isinstance(cond, sympy.Rel): + if cond.rel_op == ">=" and cond.rhs == 0: + return str(cond.lhs) + return None + + +def _suggest_torch_checks( + e: GuardOnDataDependentSymNode, src_map: defaultdict[str, list[str]] +) -> None: + """ + Enhances a GuardOnDataDependentSymNode error with suggested fixes using torch._check. + + This function analyzes the condition that caused the data-dependent error and generates + user-friendly suggestions for fixing it by adding appropriate torch._check calls. + It handles special cases like non-negative checks with specific recommendations. + + Args: + e: The GuardOnDataDependentSymNode error to enhance with suggestions + src_map: A mapping from symbol names to their corresponding source-level variable names + + Returns: + None. Modifies the error message in-place by updating e.args[0]. + """ + # extract the unresolved condition on unbacked symints in the error + cond = e.cond + diff = ", ".join(s.name for s in cond.free_symbols if s.name not in src_map) + if diff: + log.warning("Unable to find user code corresponding to {%s}", diff) + return + printer = _PythonMsgPrinter(src_map) + msg = e.args[0] + msg += "\nTo fix the error, insert one of the following checks before this call:" + + not_cond_str = printer.doprint(sympy.Not(cond)) + var_name = _is_non_negative_check(cond) + + # suggested fixes to resolve `cond` are to tell the compiler to assume + # either `cond` or its negation (the user will need to select which) + suggested_fixes = [] + + if var_name: + suggested_fixes = [ + f"You can add either: torch._check_is_size({var_name}) or torch._check({var_name}>=0)" + f" Note: torch._check_is_size({var_name}) could prevent data dependent errors that" + + " happen in a guard_size_oblivious(..) context by opting into guard_size_oblivious reasoning." + + " See documentation on guard_size_oblivious for more details:" + + " https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.guard_size_oblivious.html", + f"torch._check({not_cond_str})", + ] + else: + suggested_fixes = [ + f"torch._check({printer.doprint(cond)})", + f"torch._check({not_cond_str})", + ] + + for i, fix in enumerate(suggested_fixes): + msg += f"\n {i + 1}. {fix}" + src_mapped = ", ".join( + f"`{s}` with {' or '.join(src_map[s])}" + for s in sorted(s.name for s in cond.free_symbols) + ) + msg += f"\n\n(These suggested fixes were derived by replacing {src_mapped} in {cond} and its negation.)" + e.args = (msg,) + + +def _suggest_fixes_for_data_dependent_error_non_strict( + e: GuardOnDataDependentSymNode, +) -> None: + """ + Given a raised data-dependent error, add the following to the error message: + 1. the closest user code location that raised the error; + 2. suggested fixes for the error in terms of live variables at that location. + """ + + # walk the stack up from the data-dependent error until a non-torch frame is found + frame = _find_user_code_frame() + if frame is not None: + # add frame info to error message + _blame_user_code(e, frame) + + # map symbol names reachable via frame locals to their source-level names + src_map = defaultdict(list) + for var, val in frame.f_locals.items(): + try: + tree_leaves_with_path = pytree.tree_leaves_with_path(val) + except ValueError: + log.warning( + "pytree.tree_leaves_with_path failed for value of type {%s} in local variable {%s}", + type(val), + var, + ) + continue + # figure out how to access any symbol inside `val` through `var` + for path, leaf in tree_leaves_with_path: + name = var + pytree.keystr(path) + if isinstance(leaf, torch.SymInt): + src_map[str(leaf.node.expr)].append(name) + elif isinstance(leaf, torch.Tensor): + for i, dim in enumerate(leaf.shape): + if isinstance(dim, torch.SymInt): + src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]") + + # add suggested torch.check()s based on `src_map` to the error message + # replacing unbacked symints in the unresolved condition in the error + if isinstance(e.cond, sympy.logic.boolalg.Boolean): + _suggest_torch_checks(e, src_map) + + +@contextmanager +def _remove_effect_token_unbacked_bindings( + node: torch.fx.Node, +) -> Generator[None, None, None]: + """ + Temporarily modifies unbacked_bindings in a node's metadata by removing the first element + of each path, which corresponds to an effect token. + + This is used when processing nodes that have effect tokens as the first element in their + unbacked_bindings paths. The context manager ensures that the original bindings are + restored after the operation is complete. + + Args: + node: The FX node whose unbacked_bindings will be temporarily modified + + Yields: + None + """ + old_bindings = node.meta.get("unbacked_bindings", {}) + + # Remove the extra layer for effect token + new_bindings = {k: path[1:] if path else path for k, path in old_bindings.items()} + + node.meta["unbacked_bindings"] = new_bindings + + try: + yield + finally: + node.meta["unbacked_bindings"] = old_bindings + + +# This helper function is used in passes that insert runtime assertions in the graph. +# When accessing expressions representing input placeholders, we do not apply replacements +# since those inputs should be seen by assertions that use them to be inserted. The only replacement +# that we apply is unbacked renaming. +def _get_placeholder_expr(sym_node: SymNode) -> sympy.Expr: + shape_env = sym_node.shape_env + result = sym_node._expr + if result in shape_env.unbacked_renamings: + return shape_env.unbacked_renamings[result] + return result diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__init__.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7db0e29d1d4f75c770562c65013c03817643f6b7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/__init__.py @@ -0,0 +1,4 @@ +# mypy: disable-error-code=attr-defined +from .core import reify, unify # noqa: F403 +from .more import unifiable # noqa: F403 +from .variable import isvar, Var, var, variables, vars # noqa: F403 diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/core.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/core.py new file mode 100644 index 0000000000000000000000000000000000000000..e32f42c8968e8bc2efd7e4a8c711026ead7c569b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/core.py @@ -0,0 +1,141 @@ +# mypy: allow-untyped-defs +from collections.abc import Iterator # type: ignore[import] +from functools import partial + +from .dispatch import dispatch +from .unification_tools import assoc # type: ignore[import] +from .utils import transitive_get as walk +from .variable import isvar + + +__all__ = ["reify", "unify"] + +############### +# Reification # +############### + + +@dispatch(Iterator, dict) +def _reify(t, s): + return map(partial(reify, s=s), t) + # return (reify(arg, s) for arg in t) + + +_reify + + +@dispatch(tuple, dict) # type: ignore[no-redef] +def _reify(t, s): + return tuple(reify(iter(t), s)) + + +_reify + + +@dispatch(list, dict) # type: ignore[no-redef] +def _reify(t, s): + return list(reify(iter(t), s)) + + +_reify + + +@dispatch(dict, dict) # type: ignore[no-redef] +def _reify(d, s): + return {k: reify(v, s) for k, v in d.items()} + + +_reify + + +@dispatch(object, dict) # type: ignore[no-redef] +def _reify(o, s): + return o # catch all, just return the object + + +def reify(e, s): + """Replace variables of expression with substitution + >>> # xdoctest: +SKIP + >>> x, y = var(), var() + >>> e = (1, x, (3, y)) + >>> s = {x: 2, y: 4} + >>> reify(e, s) + (1, 2, (3, 4)) + >>> e = {1: x, 3: (y, 5)} + >>> reify(e, s) + {1: 2, 3: (4, 5)} + """ + if isvar(e): + return reify(s[e], s) if e in s else e + return _reify(e, s) + + +############### +# Unification # +############### + +seq = tuple, list, Iterator + + +@dispatch(seq, seq, dict) +def _unify(u, v, s): + if len(u) != len(v): + return False + for uu, vv in zip(u, v): # avoiding recursion + s = unify(uu, vv, s) + if s is False: + return False + return s + + +# +# @dispatch((set, frozenset), (set, frozenset), dict) +# def _unify(u, v, s): +# i = u & v +# u = u - i +# v = v - i +# return _unify(sorted(u), sorted(v), s) +# +# +# @dispatch(dict, dict, dict) +# def _unify(u, v, s): +# if len(u) != len(v): +# return False +# for key, uval in iteritems(u): +# if key not in v: +# return False +# s = unify(uval, v[key], s) +# if s is False: +# return False +# return s +# +# +# @dispatch(object, object, dict) +# def _unify(u, v, s): +# return False # catch all + + +@dispatch(object, object, dict) +def unify(u, v, s): # no check at the moment + """Find substitution so that u == v while satisfying s + >>> x = var("x") + >>> unify((1, x), (1, 2), {}) + {~x: 2} + """ + u = walk(u, s) + v = walk(v, s) + if u == v: + return s + if isvar(u): + return assoc(s, u, v) + if isvar(v): + return assoc(s, v, u) + return _unify(u, v, s) + + +unify + + +@dispatch(object, object) # type: ignore[no-redef] +def unify(u, v): + return unify(u, v, {}) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/dispatch.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..82d62e1f161971cc84c2cb85c138838ed488e639 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/dispatch.py @@ -0,0 +1,8 @@ +from functools import partial + +from .multipledispatch import dispatch # type: ignore[import] + + +namespace = {} # type: ignore[var-annotated] + +dispatch = partial(dispatch, namespace=namespace) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/match.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/match.py new file mode 100644 index 0000000000000000000000000000000000000000..01861a086f64b6121aa9e174d16176533cd0e1a5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/match.py @@ -0,0 +1,129 @@ +# mypy: allow-untyped-defs +from .core import reify, unify # type: ignore[attr-defined] +from .unification_tools import first, groupby # type: ignore[import] +from .utils import _toposort, freeze +from .variable import isvar + + +class Dispatcher: + def __init__(self, name): + self.name = name + self.funcs = {} + self.ordering = [] + + def add(self, signature, func): + self.funcs[freeze(signature)] = func + self.ordering = ordering(self.funcs) + + def __call__(self, *args, **kwargs): + func, _ = self.resolve(args) + return func(*args, **kwargs) + + def resolve(self, args): + n = len(args) + for signature in self.ordering: + if len(signature) != n: + continue + s = unify(freeze(args), signature) + if s is not False: + result = self.funcs[signature] + return result, s + raise NotImplementedError( + "No match found. \nKnown matches: " + + str(self.ordering) + + "\nInput: " + + str(args) + ) + + def register(self, *signature): + def _(func): + self.add(signature, func) + return self + + return _ + + +class VarDispatcher(Dispatcher): + """A dispatcher that calls functions with variable names + >>> # xdoctest: +SKIP + >>> d = VarDispatcher("d") + >>> x = var("x") + >>> @d.register("inc", x) + ... def f(x): + ... return x + 1 + >>> @d.register("double", x) + ... def f(x): + ... return x * 2 + >>> d("inc", 10) + 11 + >>> d("double", 10) + 20 + """ + + def __call__(self, *args, **kwargs): + func, s = self.resolve(args) + d = {k.token: v for k, v in s.items()} + return func(**d) + + +global_namespace = {} # type: ignore[var-annotated] + + +def match(*signature, **kwargs): + namespace = kwargs.get("namespace", global_namespace) + dispatcher = kwargs.get("Dispatcher", Dispatcher) + + def _(func): + name = func.__name__ + + if name not in namespace: + namespace[name] = dispatcher(name) + d = namespace[name] + + d.add(signature, func) + + return d + + return _ + + +def supercedes(a, b): + """``a`` is a more specific match than ``b``""" + if isvar(b) and not isvar(a): + return True + s = unify(a, b) + if s is False: + return False + s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)} + if reify(a, s) == a: + return True + if reify(b, s) == b: + return False + + +# Taken from multipledispatch +def edge(a, b, tie_breaker=hash): + """A should be checked before B + Tie broken by tie_breaker, defaults to ``hash`` + """ + if supercedes(a, b): + if supercedes(b, a): + return tie_breaker(a) > tie_breaker(b) + else: + return True + return False + + +# Taken from multipledispatch +def ordering(signatures): + """A sane ordering of signatures to check, first to last + Topological sort of edges as given by ``edge`` and ``supercedes`` + """ + signatures = list(map(tuple, signatures)) + edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] + edges = groupby(first, edges) + for s in signatures: + if s not in edges: + edges[s] = [] + edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment] + return _toposort(edges) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/more.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/more.py new file mode 100644 index 0000000000000000000000000000000000000000..da2b1773f95ba096fa661cb958d849c3674c835f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/more.py @@ -0,0 +1,122 @@ +# mypy: allow-untyped-defs +from .core import reify, unify # type: ignore[attr-defined] +from .dispatch import dispatch + + +def unifiable(cls): + """Register standard unify and reify operations on class + This uses the type and __dict__ or __slots__ attributes to define the + nature of the term + See Also: + >>> # xdoctest: +SKIP + >>> class A(object): + ... def __init__(self, a, b): + ... self.a = a + ... self.b = b + >>> unifiable(A) + + >>> x = var("x") + >>> a = A(1, 2) + >>> b = A(1, x) + >>> unify(a, b, {}) + {~x: 2} + """ + _unify.add((cls, cls, dict), unify_object) + _reify.add((cls, dict), reify_object) + + return cls + + +######### +# Reify # +######### + + +def reify_object(o, s): + """Reify a Python object with a substitution + >>> # xdoctest: +SKIP + >>> class Foo(object): + ... def __init__(self, a, b): + ... self.a = a + ... self.b = b + ... + ... def __str__(self): + ... return "Foo(%s, %s)" % (str(self.a), str(self.b)) + >>> x = var("x") + >>> f = Foo(1, x) + >>> print(f) + Foo(1, ~x) + >>> print(reify_object(f, {x: 2})) + Foo(1, 2) + """ + if hasattr(o, "__slots__"): + return _reify_object_slots(o, s) + else: + return _reify_object_dict(o, s) + + +def _reify_object_dict(o, s): + obj = object.__new__(type(o)) + d = reify(o.__dict__, s) + if d == o.__dict__: + return o + obj.__dict__.update(d) + return obj + + +def _reify_object_slots(o, s): + attrs = [getattr(o, attr) for attr in o.__slots__] + new_attrs = reify(attrs, s) + if attrs == new_attrs: + return o + else: + newobj = object.__new__(type(o)) + for slot, attr in zip(o.__slots__, new_attrs): + setattr(newobj, slot, attr) + return newobj + + +@dispatch(slice, dict) +def _reify(o, s): + """Reify a Python ``slice`` object""" + return slice(*reify((o.start, o.stop, o.step), s)) + + +######### +# Unify # +######### + + +def unify_object(u, v, s): + """Unify two Python objects + Unifies their type and ``__dict__`` attributes + >>> # xdoctest: +SKIP + >>> class Foo(object): + ... def __init__(self, a, b): + ... self.a = a + ... self.b = b + ... + ... def __str__(self): + ... return "Foo(%s, %s)" % (str(self.a), str(self.b)) + >>> x = var("x") + >>> f = Foo(1, x) + >>> g = Foo(1, 2) + >>> unify_object(f, g, {}) + {~x: 2} + """ + if type(u) != type(v): + return False + if hasattr(u, "__slots__"): + return unify( + [getattr(u, slot) for slot in u.__slots__], + [getattr(v, slot) for slot in v.__slots__], + s, + ) + else: + return unify(u.__dict__, v.__dict__, s) + + +@dispatch(slice, slice, dict) +def _unify(u, v, s): + """Unify a Python ``slice`` object""" + return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb7304069243fb45604e165b06b377a5db233a7d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py @@ -0,0 +1,7 @@ +from .core import dispatch +from .dispatcher import ( + Dispatcher, + halt_ordering, + MDNotImplementedError, + restart_ordering, +) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py new file mode 100644 index 0000000000000000000000000000000000000000..44a893ad56a40b69e600dca737860fd3df69e4f4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py @@ -0,0 +1,139 @@ +# mypy: allow-untyped-defs +import operator + +from .utils import _toposort, groupby +from .variadic import isvariadic + + +__all__ = [ + "AmbiguityWarning", + "supercedes", + "consistent", + "ambiguous", + "ambiguities", + "super_signature", + "edge", + "ordering", +] + + +class AmbiguityWarning(Warning): + pass + + +def supercedes(a, b): + """A is consistent and strictly more specific than B""" + if len(a) < len(b): + # only case is if a is empty and b is variadic + return not a and len(b) == 1 and isvariadic(b[-1]) + elif len(a) == len(b): + return all(map(issubclass, a, b)) + else: + # len(a) > len(b) + p1 = 0 + p2 = 0 + while p1 < len(a) and p2 < len(b): + cur_a = a[p1] + cur_b = b[p2] + if not (isvariadic(cur_a) or isvariadic(cur_b)): + if not issubclass(cur_a, cur_b): + return False + p1 += 1 + p2 += 1 + elif isvariadic(cur_a): + assert p1 == len(a) - 1 + return p2 == len(b) - 1 and issubclass(cur_a, cur_b) + elif isvariadic(cur_b): + assert p2 == len(b) - 1 + if not issubclass(cur_a, cur_b): + return False + p1 += 1 + return p2 == len(b) - 1 and p1 == len(a) + + +def consistent(a, b): + """It is possible for an argument list to satisfy both A and B""" + + # Need to check for empty args + if not a: + return not b or isvariadic(b[0]) + if not b: + return not a or isvariadic(a[0]) + + # Non-empty args check for mutual subclasses + if len(a) == len(b): + return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b)) + else: + p1 = 0 + p2 = 0 + while p1 < len(a) and p2 < len(b): + cur_a = a[p1] + cur_b = b[p2] + if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b): + return False + if not (isvariadic(cur_a) or isvariadic(cur_b)): + p1 += 1 + p2 += 1 + elif isvariadic(cur_a): + p2 += 1 + elif isvariadic(cur_b): + p1 += 1 + # We only need to check for variadic ends + # Variadic types are guaranteed to be the last element + return ( + isvariadic(cur_a) # type: ignore[possibly-undefined] + and p2 == len(b) + or isvariadic(cur_b) # type: ignore[possibly-undefined] + and p1 == len(a) + ) + + +def ambiguous(a, b): + """A is consistent with B but neither is strictly more specific""" + return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) + + +def ambiguities(signatures): + """All signature pairs such that A is ambiguous with B""" + signatures = list(map(tuple, signatures)) + return { + (a, b) + for a in signatures + for b in signatures + if hash(a) < hash(b) + and ambiguous(a, b) + and not any(supercedes(c, a) and supercedes(c, b) for c in signatures) + } + + +def super_signature(signatures): + """A signature that would break ambiguities""" + n = len(signatures[0]) + assert all(len(s) == n for s in signatures) + + return [max((type.mro(sig[i]) for sig in signatures), key=len)[0] for i in range(n)] + + +def edge(a, b, tie_breaker=hash): + """A should be checked before B + Tie broken by tie_breaker, defaults to ``hash`` + """ + # A either supercedes B and B does not supercede A or if B does then call + # tie_breaker + return supercedes(a, b) and ( + not supercedes(b, a) or tie_breaker(a) > tie_breaker(b) + ) + + +def ordering(signatures): + """A sane ordering of signatures to check, first to last + Topological sort of edges as given by ``edge`` and ``supercedes`` + """ + signatures = list(map(tuple, signatures)) + edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] + edges = groupby(operator.itemgetter(0), edges) + for s in signatures: + if s not in edges: + edges[s] = [] + edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[assignment, attr-defined] + return _toposort(edges) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/core.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/core.py new file mode 100644 index 0000000000000000000000000000000000000000..f1f09dcf559c7022693bf89e1cd56d5fa01315eb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/core.py @@ -0,0 +1,84 @@ +# mypy: allow-untyped-defs +import inspect + +from .dispatcher import Dispatcher, MethodDispatcher + + +global_namespace = {} # type: ignore[var-annotated] + +__all__ = ["dispatch", "ismethod"] + + +def dispatch(*types, **kwargs): + """Dispatch function on the types of the inputs + Supports dispatch on all non-keyword arguments. + Collects implementations based on the function name. Ignores namespaces. + If ambiguous type signatures occur a warning is raised when the function is + defined suggesting the additional method to break the ambiguity. + + Example: + >>> # xdoctest: +SKIP + >>> @dispatch(int) + ... def f(x): + ... return x + 1 + >>> @dispatch(float) + ... def f(x): + ... return x - 1 + >>> # xdoctest: +SKIP + >>> f(3) + 4 + >>> f(3.0) + 2.0 + >>> # Specify an isolated namespace with the namespace keyword argument + >>> my_namespace = {} + >>> @dispatch(int, namespace=my_namespace) + ... def foo(x): + ... return x + 1 + >>> # Dispatch on instance methods within classes + >>> class MyClass(object): + ... @dispatch(list) + ... def __init__(self, data): + ... self.data = data + ... + ... @dispatch(int) + ... def __init__(self, datum): + ... self.data = [datum] + >>> MyClass([1, 2, 3]).data + [1, 2, 3] + >>> MyClass(3).data + [3] + """ + namespace = kwargs.get("namespace", global_namespace) + + types = tuple(types) + + def _df(func): + name = func.__name__ + + if ismethod(func): + dispatcher = inspect.currentframe().f_back.f_locals.get( # type: ignore[union-attr] + name, # type: ignore[union-attr] + MethodDispatcher(name), + ) + else: + if name not in namespace: + namespace[name] = Dispatcher(name) + dispatcher = namespace[name] + + dispatcher.add(types, func) + return dispatcher + + return _df + + +def ismethod(func): + """Is func a method? + Note that this has to work as the method is defined but before the class is + defined. At this stage methods look like functions. + """ + if hasattr(inspect, "signature"): + signature = inspect.signature(func) + return signature.parameters.get("self", None) is not None + else: + spec = inspect.getfullargspec(func) # type: ignore[union-attr, assignment] + return spec and spec.args and spec.args[0] == "self" diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..11cc8bd59a736c3885eb5daa84072b687369f735 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -0,0 +1,453 @@ +# mypy: allow-untyped-defs +import inspect +import itertools as itl +from typing_extensions import deprecated +from warnings import warn + +from .conflict import ambiguities, AmbiguityWarning, ordering, super_signature +from .utils import expand_tuples +from .variadic import isvariadic, Variadic + + +__all__ = [ + "MDNotImplementedError", + "ambiguity_warn", + "halt_ordering", + "restart_ordering", + "variadic_signature_matches_iter", + "variadic_signature_matches", + "Dispatcher", + "source", + "MethodDispatcher", + "str_signature", + "warning_text", +] + + +class MDNotImplementedError(NotImplementedError): + """A NotImplementedError for multiple dispatch""" + + +def ambiguity_warn(dispatcher, ambiguities): + """Raise warning when ambiguity is detected + Parameters + ---------- + dispatcher : Dispatcher + The dispatcher on which the ambiguity was detected + ambiguities : set + Set of type signature pairs that are ambiguous within this dispatcher + See Also: + Dispatcher.add + warning_text + """ + warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) + + +@deprecated( + "`halt_ordering` is deprecated, you can safely remove this call.", + category=FutureWarning, +) +def halt_ordering(): + """Deprecated interface to temporarily disable ordering.""" + + +@deprecated( + "`restart_ordering` is deprecated, if you would like to eagerly order the dispatchers, " + "you should call the `reorder()` method on each dispatcher.", + category=FutureWarning, +) +def restart_ordering(on_ambiguity=ambiguity_warn): + """Deprecated interface to temporarily resume ordering.""" + + +def variadic_signature_matches_iter(types, full_signature): + """Check if a set of input types matches a variadic signature. + Notes + ----- + The algorithm is as follows: + Initialize the current signature to the first in the sequence + For each type in `types`: + If the current signature is variadic + If the type matches the signature + yield True + Else + Try to get the next signature + If no signatures are left we can't possibly have a match + so yield False + Else + yield True if the type matches the current signature + Get the next signature + """ + sigiter = iter(full_signature) + sig = next(sigiter) + for typ in types: + matches = issubclass(typ, sig) + yield matches + if not isvariadic(sig): + # we're not matching a variadic argument, so move to the next + # element in the signature + sig = next(sigiter) + else: + try: + sig = next(sigiter) + except StopIteration: + assert isvariadic(sig) + yield True + else: + # We have signature items left over, so all of our arguments + # haven't matched + yield False + + +def variadic_signature_matches(types, full_signature): + # No arguments always matches a variadic signature + assert full_signature + return all(variadic_signature_matches_iter(types, full_signature)) + + +class Dispatcher: + """Dispatch methods based on type signature + Use ``dispatch`` to add implementations + Examples + -------- + >>> # xdoctest: +SKIP("bad import name") + >>> from multipledispatch import dispatch + >>> @dispatch(int) + ... def f(x): + ... return x + 1 + >>> @dispatch(float) + ... def f(x): + ... return x - 1 + >>> f(3) + 4 + >>> f(3.0) + 2.0 + """ + + __slots__ = "__name__", "name", "funcs", "_ordering", "_cache", "doc" + + def __init__(self, name, doc=None): + self.name = self.__name__ = name + self.funcs = {} + self.doc = doc + + self._cache = {} + + def register(self, *types, **kwargs): + """register dispatcher with new implementation + >>> # xdoctest: +SKIP + >>> f = Dispatcher("f") + >>> @f.register(int) + ... def inc(x): + ... return x + 1 + >>> @f.register(float) + ... def dec(x): + ... return x - 1 + >>> @f.register(list) + ... @f.register(tuple) + ... def reverse(x): + ... return x[::-1] + >>> f(1) + 2 + >>> f(1.0) + 0.0 + >>> f([1, 2, 3]) + [3, 2, 1] + """ + + def _df(func): + self.add(types, func, **kwargs) # type: ignore[call-arg] + return func + + return _df + + @classmethod + def get_func_params(cls, func): + if hasattr(inspect, "signature"): + sig = inspect.signature(func) + return sig.parameters.values() + + @classmethod + def get_func_annotations(cls, func): + """get annotations of function positional parameters""" + params = cls.get_func_params(func) + if params: + Parameter = inspect.Parameter + + params = ( + param + for param in params + if param.kind + in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) + ) + + annotations = tuple(param.annotation for param in params) + + if all(ann is not Parameter.empty for ann in annotations): + return annotations + + def add(self, signature, func): + """Add new types/method pair to dispatcher + >>> # xdoctest: +SKIP + >>> D = Dispatcher("add") + >>> D.add((int, int), lambda x, y: x + y) + >>> D.add((float, float), lambda x, y: x + y) + >>> D(1, 2) + 3 + >>> D(1, 2.0) + Traceback (most recent call last): + ... + NotImplementedError: Could not find signature for add: + >>> # When ``add`` detects a warning it calls the ``on_ambiguity`` callback + >>> # with a dispatcher/itself, and a set of ambiguous type signature pairs + >>> # as inputs. See ``ambiguity_warn`` for an example. + """ + # Handle annotations + if not signature: + annotations = self.get_func_annotations(func) + if annotations: + signature = annotations + + # Handle union types + if any(isinstance(typ, tuple) for typ in signature): + for typs in expand_tuples(signature): + self.add(typs, func) + return + + new_signature = [] + + for index, typ in enumerate(signature, start=1): + if not isinstance(typ, (type, list)): + str_sig = ", ".join( + c.__name__ if isinstance(c, type) else str(c) for c in signature + ) + raise TypeError( + f"Tried to dispatch on non-type: {typ}\n" + f"In signature: <{str_sig}>\n" + f"In function: {self.name}" + ) + + # handle variadic signatures + if isinstance(typ, list): + if index != len(signature): + raise TypeError("Variadic signature must be the last element") + + if len(typ) != 1: + raise TypeError( + "Variadic signature must contain exactly one element. " + "To use a variadic union type place the desired types " + "inside of a tuple, e.g., [(int, str)]" + ) + new_signature.append(Variadic[typ[0]]) + else: + new_signature.append(typ) + + self.funcs[tuple(new_signature)] = func + self._cache.clear() + + try: + del self._ordering + except AttributeError: + pass + + @property + def ordering(self): + try: + return self._ordering + except AttributeError: + return self.reorder() + + def reorder(self, on_ambiguity=ambiguity_warn): + self._ordering = od = ordering(self.funcs) + amb = ambiguities(self.funcs) + if amb: + on_ambiguity(self, amb) + return od + + def __call__(self, *args, **kwargs): + types = tuple([type(arg) for arg in args]) + try: + func = self._cache[types] + except KeyError as e: + func = self.dispatch(*types) + if not func: + raise NotImplementedError( + f"Could not find signature for {self.name}: <{str_signature(types)}>" + ) from e + self._cache[types] = func + try: + return func(*args, **kwargs) + + except MDNotImplementedError as e: + funcs = self.dispatch_iter(*types) + next(funcs) # burn first + for func in funcs: + try: + return func(*args, **kwargs) + except MDNotImplementedError: + pass + + raise NotImplementedError( + "Matching functions for " + f"{self.name}: <{str_signature(types)}> found, but none completed successfully", + ) from e + + def __str__(self): + return f"" + + __repr__ = __str__ + + def dispatch(self, *types): + """Determine appropriate implementation for this type signature + This method is internal. Users should call this object as a function. + Implementation resolution occurs within the ``__call__`` method. + >>> # xdoctest: +SKIP + >>> from multipledispatch import dispatch + >>> @dispatch(int) + ... def inc(x): + ... return x + 1 + >>> implementation = inc.dispatch(int) + >>> implementation(3) + 4 + >>> print(inc.dispatch(float)) + None + See Also: + ``multipledispatch.conflict`` - module to determine resolution order + """ + + if types in self.funcs: + return self.funcs[types] + + try: + return next(self.dispatch_iter(*types)) + except StopIteration: + return None + + def dispatch_iter(self, *types): + n = len(types) + for signature in self.ordering: + if len(signature) == n and all(map(issubclass, types, signature)): + result = self.funcs[signature] + yield result + elif len(signature) and isvariadic(signature[-1]): + if variadic_signature_matches(types, signature): + result = self.funcs[signature] + yield result + + @deprecated( + "`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning + ) + def resolve(self, types): + """Determine appropriate implementation for this type signature + .. deprecated:: 0.4.4 + Use ``dispatch(*types)`` instead + """ + return self.dispatch(*types) + + def __getstate__(self): + return {"name": self.name, "funcs": self.funcs} + + def __setstate__(self, d): + self.name = d["name"] + self.funcs = d["funcs"] + self._ordering = ordering(self.funcs) + self._cache = {} + + @property + def __doc__(self): # type: ignore[override] + docs = [f"Multiply dispatched method: {self.name}"] + + if self.doc: + docs.append(self.doc) + + other = [] + for sig in self.ordering[::-1]: + func = self.funcs[sig] + if func.__doc__: + s = f"Inputs: <{str_signature(sig)}>\n" + s += "-" * len(s) + "\n" + s += func.__doc__.strip() + docs.append(s) + else: + other.append(str_signature(sig)) + + if other: + docs.append("Other signatures:\n " + "\n ".join(other)) + + return "\n\n".join(docs) + + def _help(self, *args): + return self.dispatch(*map(type, args)).__doc__ + + def help(self, *args, **kwargs): + """Print docstring for the function corresponding to inputs""" + print(self._help(*args)) + + def _source(self, *args): + func = self.dispatch(*map(type, args)) + if not func: + raise TypeError("No function found") + return source(func) + + def source(self, *args, **kwargs): + """Print source code for the function corresponding to inputs""" + print(self._source(*args)) + + +def source(func): + s = f"File: {inspect.getsourcefile(func)}\n\n" + s = s + inspect.getsource(func) + return s + + +class MethodDispatcher(Dispatcher): + """Dispatch methods based on type signature + See Also: + Dispatcher + """ + + __slots__ = ("obj", "cls") + + @classmethod + def get_func_params(cls, func): + if hasattr(inspect, "signature"): + sig = inspect.signature(func) + return itl.islice(sig.parameters.values(), 1, None) + + def __get__(self, instance, owner): + self.obj = instance + self.cls = owner + return self + + def __call__(self, *args, **kwargs): + types = tuple([type(arg) for arg in args]) + func = self.dispatch(*types) + if not func: + raise NotImplementedError( + f"Could not find signature for {self.name}: <{str_signature(types)}>" + ) + return func(self.obj, *args, **kwargs) + + +def str_signature(sig): + """String representation of type signature + >>> str_signature((int, float)) + 'int, float' + """ + return ", ".join(cls.__name__ for cls in sig) + + +def warning_text(name, amb): + """The text for ambiguity warnings""" + text = f"\nAmbiguities exist in dispatched function {name}\n\n" + text += "The following signatures may result in ambiguous behavior:\n" + for pair in amb: + text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n" + text += "\n\nConsider making the following additions:\n\n" + text += "\n\n".join( + [ + "@dispatch(" + str_signature(super_signature(s)) + f")\ndef {name}(...)" + for s in amb + ] + ) + return text diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c91cca2067afcd406aa35c51363417ca4ada2e0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py @@ -0,0 +1,127 @@ +# mypy: allow-untyped-defs +from collections import OrderedDict + + +__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"] + + +def raises(err, lamda): + try: + lamda() + return False + except err: + return True + + +def expand_tuples(L): + """ + >>> expand_tuples([1, (2, 3)]) + [(1, 2), (1, 3)] + >>> expand_tuples([1, 2]) + [(1, 2)] + """ + if not L: + return [()] + elif not isinstance(L[0], tuple): + rest = expand_tuples(L[1:]) + return [(L[0],) + t for t in rest] + else: + rest = expand_tuples(L[1:]) + return [(item,) + t for t in rest for item in L[0]] + + +# Taken from theano/theano/gof/sched.py +# Avoids licensing issues because this was written by Matthew Rocklin +def _toposort(edges): + """Topological sort algorithm by Kahn [1] - O(nodes + vertices) + inputs: + edges - a dict of the form {a: {b, c}} where b and c depend on a + outputs: + L - an ordered list of nodes that satisfy the dependencies of edges + >>> _toposort({1: (2, 3), 2: (3,)}) + [1, 2, 3] + >>> # Closely follows the wikipedia page [2] + >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", + >>> # Communications of the ACM + >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms + """ + incoming_edges = reverse_dict(edges) + incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items()) + S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) + L = [] + + while S: + n, _ = S.popitem() + L.append(n) + for m in edges.get(n, ()): + assert n in incoming_edges[m] + incoming_edges[m].remove(n) + if not incoming_edges[m]: + S[m] = None + if any(incoming_edges.get(v, None) for v in edges): + raise ValueError("Input has cycles") + return L + + +def reverse_dict(d): + """Reverses direction of dependence dict + >>> d = {"a": (1, 2), "b": (2, 3), "c": ()} + >>> reverse_dict(d) # doctest: +SKIP + {1: ('a',), 2: ('a', 'b'), 3: ('b',)} + :note: dict order are not deterministic. As we iterate on the + input dict, it make the output of this function depend on the + dict order. So this function output order should be considered + as undeterministic. + """ + result = OrderedDict() # type: ignore[var-annotated] + for key in d: + for val in d[key]: + result[val] = result.get(val, ()) + (key,) + return result + + +# Taken from toolz +# Avoids licensing issues because this version was authored by Matthew Rocklin +def groupby(func, seq): + """Group a collection by a key function + >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"] + >>> groupby(len, names) # doctest: +SKIP + {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} + >>> iseven = lambda x: x % 2 == 0 + >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP + {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} + See Also: + ``countby`` + """ + + d = OrderedDict() # type: ignore[var-annotated] + for item in seq: + key = func(item) + if key not in d: + d[key] = [] + d[key].append(item) + return d + + +def typename(type): + """Get the name of `type`. + Parameters + ---------- + type : Union[Type, Tuple[Type]] + Returns + ------- + str + The name of `type` or a tuple of the names of the types in `type`. + Examples + -------- + >>> typename(int) + 'int' + >>> typename((int, float)) + '(int, float)' + """ + try: + return type.__name__ + except AttributeError: + if len(type) == 1: + return typename(*type) + return f"({', '.join(map(typename, type))})" diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5604a152480f83916108cb1b02de3bc9b9adb5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +from .utils import typename + + +__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"] + + +class VariadicSignatureType(type): + # checking if subclass is a subclass of self + def __subclasscheck__(cls, subclass): + other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,) + return subclass is cls or all( + issubclass(other, cls.variadic_type) # type: ignore[attr-defined] + for other in other_type + ) + + def __eq__(cls, other): + """ + Return True if other has the same variadic type + Parameters + ---------- + other : object (type) + The object (type) to check + Returns + ------- + bool + Whether or not `other` is equal to `self` + """ + return isvariadic(other) and set(cls.variadic_type) == set(other.variadic_type) # type: ignore[attr-defined] + + def __hash__(cls): + return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined] + + +def isvariadic(obj): + """Check whether the type `obj` is variadic. + Parameters + ---------- + obj : type + The type to check + Returns + ------- + bool + Whether or not `obj` is variadic + Examples + -------- + >>> # xdoctest: +SKIP + >>> isvariadic(int) + False + >>> isvariadic(Variadic[int]) + True + """ + return isinstance(obj, VariadicSignatureType) + + +class VariadicSignatureMeta(type): + """A metaclass that overrides ``__getitem__`` on the class. This is used to + generate a new type for Variadic signatures. See the Variadic class for + examples of how this behaves. + """ + + def __getitem__(cls, variadic_type): + if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): + raise ValueError( + "Variadic types must be type or tuple of types" + " (Variadic[int] or Variadic[(int, float)]" + ) + + if not isinstance(variadic_type, tuple): + variadic_type = (variadic_type,) + return VariadicSignatureType( + f"Variadic[{typename(variadic_type)}]", + (), + dict(variadic_type=variadic_type, __slots__=()), + ) + + +class Variadic(metaclass=VariadicSignatureMeta): + """A class whose getitem method can be used to generate a new type + representing a specific variadic signature. + Examples + -------- + >>> # xdoctest: +SKIP + >>> Variadic[int] # any number of int arguments + + >>> Variadic[(int, str)] # any number of one of int or str arguments + + >>> issubclass(int, Variadic[int]) + True + >>> issubclass(int, Variadic[(int, str)]) + True + >>> issubclass(str, Variadic[(int, str)]) + True + >>> issubclass(float, Variadic[(int, str)]) + False + """ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/unification_tools.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/unification_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..a47d900273f5ea4d1fcbeae1be35f8685f5b0a32 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/unification_tools.py @@ -0,0 +1,419 @@ +# mypy: allow-untyped-defs +import collections +import operator +from collections.abc import Mapping +from functools import reduce + + +__all__ = [ + "merge", + "merge_with", + "valmap", + "keymap", + "itemmap", + "valfilter", + "keyfilter", + "itemfilter", + "assoc", + "dissoc", + "assoc_in", + "update_in", + "get_in", +] + + +def _get_factory(f, kwargs): + factory = kwargs.pop("factory", dict) + if kwargs: + raise TypeError( + f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'" + ) + return factory + + +def merge(*dicts, **kwargs): + """Merge a collection of dictionaries + + >>> merge({1: "one"}, {2: "two"}) + {1: 'one', 2: 'two'} + + Later dictionaries have precedence + + >>> merge({1: 2, 3: 4}, {3: 3, 4: 4}) + {1: 2, 3: 3, 4: 4} + + See Also: + merge_with + """ + if len(dicts) == 1 and not isinstance(dicts[0], Mapping): + dicts = dicts[0] + factory = _get_factory(merge, kwargs) + + rv = factory() + for d in dicts: + rv.update(d) + return rv + + +def merge_with(func, *dicts, **kwargs): + """Merge dictionaries and apply function to combined values + + A key may occur in more than one dict, and all values mapped from the key + will be passed to the function as a list, such as func([val1, val2, ...]). + + >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20}) + {1: 11, 2: 22} + + >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP + {1: 1, 2: 2, 3: 30} + + See Also: + merge + """ + if len(dicts) == 1 and not isinstance(dicts[0], Mapping): + dicts = dicts[0] + factory = _get_factory(merge_with, kwargs) + + result = factory() + for d in dicts: + for k, v in d.items(): + if k not in result: + result[k] = [v] + else: + result[k].append(v) + return valmap(func, result, factory) + + +def valmap(func, d, factory=dict): + """Apply function to values of dictionary + + >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} + >>> valmap(sum, bills) # doctest: +SKIP + {'Alice': 65, 'Bob': 45} + + See Also: + keymap + itemmap + """ + rv = factory() + rv.update(zip(d.keys(), map(func, d.values()))) + return rv + + +def keymap(func, d, factory=dict): + """Apply function to keys of dictionary + + >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} + >>> keymap(str.lower, bills) # doctest: +SKIP + {'alice': [20, 15, 30], 'bob': [10, 35]} + + See Also: + valmap + itemmap + """ + rv = factory() + rv.update(zip(map(func, d.keys()), d.values())) + return rv + + +def itemmap(func, d, factory=dict): + """Apply function to items of dictionary + + >>> accountids = {"Alice": 10, "Bob": 20} + >>> itemmap(reversed, accountids) # doctest: +SKIP + {10: "Alice", 20: "Bob"} + + See Also: + keymap + valmap + """ + rv = factory() + rv.update(map(func, d.items())) + return rv + + +def valfilter(predicate, d, factory=dict): + """Filter items in dictionary by value + + >>> iseven = lambda x: x % 2 == 0 + >>> d = {1: 2, 2: 3, 3: 4, 4: 5} + >>> valfilter(iseven, d) + {1: 2, 3: 4} + + See Also: + keyfilter + itemfilter + valmap + """ + rv = factory() + for k, v in d.items(): + if predicate(v): + rv[k] = v + return rv + + +def keyfilter(predicate, d, factory=dict): + """Filter items in dictionary by key + + >>> iseven = lambda x: x % 2 == 0 + >>> d = {1: 2, 2: 3, 3: 4, 4: 5} + >>> keyfilter(iseven, d) + {2: 3, 4: 5} + + See Also: + valfilter + itemfilter + keymap + """ + rv = factory() + for k, v in d.items(): + if predicate(k): + rv[k] = v + return rv + + +def itemfilter(predicate, d, factory=dict): + """Filter items in dictionary by item + + >>> def isvalid(item): + ... k, v = item + ... return k % 2 == 0 and v < 4 + + >>> d = {1: 2, 2: 3, 3: 4, 4: 5} + >>> itemfilter(isvalid, d) + {2: 3} + + See Also: + keyfilter + valfilter + itemmap + """ + rv = factory() + for item in d.items(): + if predicate(item): + k, v = item + rv[k] = v + return rv + + +def assoc(d, key, value, factory=dict): + """Return a new dict with new key value pair + + New dict has d[key] set to value. Does not modify the initial dictionary. + + >>> assoc({"x": 1}, "x", 2) + {'x': 2} + >>> assoc({"x": 1}, "y", 3) # doctest: +SKIP + {'x': 1, 'y': 3} + """ + d2 = factory() + d2.update(d) + d2[key] = value + return d2 + + +def dissoc(d, *keys, **kwargs): + """Return a new dict with the given key(s) removed. + + New dict has d[key] deleted for each supplied key. + Does not modify the initial dictionary. + + >>> dissoc({"x": 1, "y": 2}, "y") + {'x': 1} + >>> dissoc({"x": 1, "y": 2}, "y", "x") + {} + >>> dissoc({"x": 1}, "y") # Ignores missing keys + {'x': 1} + """ + factory = _get_factory(dissoc, kwargs) + d2 = factory() + + if len(keys) < len(d) * 0.6: + d2.update(d) + for key in keys: + if key in d2: + del d2[key] + else: + remaining = set(d) + remaining.difference_update(keys) + for k in remaining: + d2[k] = d[k] + return d2 + + +def assoc_in(d, keys, value, factory=dict): + """Return a new dict with new, potentially nested, key value pair + + >>> purchase = { + ... "name": "Alice", + ... "order": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> assoc_in(purchase, ["order", "costs"], [0.25, 1.00]) # doctest: +SKIP + {'credit card': '5555-1234-1234-1234', + 'name': 'Alice', + 'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}} + """ + return update_in(d, keys, lambda x: value, value, factory) + + +def update_in(d, keys, func, default=None, factory=dict): + """Update value in a (potentially) nested dictionary + + inputs: + d - dictionary on which to operate + keys - list or tuple giving the location of the value to be changed in d + func - function to operate on that value + + If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the + original dictionary with v replaced by func(v), but does not mutate the + original dictionary. + + If k0 is not a key in d, update_in creates nested dictionaries to the depth + specified by the keys, with the innermost value set to func(default). + + >>> inc = lambda x: x + 1 + >>> update_in({"a": 0}, ["a"], inc) + {'a': 1} + + >>> transaction = { + ... "name": "Alice", + ... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> update_in(transaction, ["purchase", "costs"], sum) # doctest: +SKIP + {'credit card': '5555-1234-1234-1234', + 'name': 'Alice', + 'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}} + + >>> # updating a value when k0 is not in d + >>> update_in({}, [1, 2, 3], str, default="bar") + {1: {2: {3: 'bar'}}} + >>> update_in({1: "foo"}, [2, 3, 4], inc, 0) + {1: 'foo', 2: {3: {4: 1}}} + """ + ks = iter(keys) + k = next(ks) + + rv = inner = factory() + rv.update(d) + + for key in ks: + if k in d: + d = d[k] + dtemp = factory() + dtemp.update(d) + else: + d = dtemp = factory() + + inner[k] = inner = dtemp + k = key + + if k in d: + inner[k] = func(d[k]) + else: + inner[k] = func(default) + return rv + + +def get_in(keys, coll, default=None, no_default=False): + """Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys. + + If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless + ``no_default`` is specified, then it raises KeyError or IndexError. + + ``get_in`` is a generalization of ``operator.getitem`` for nested data + structures such as dictionaries and lists. + + >>> transaction = { + ... "name": "Alice", + ... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> get_in(["purchase", "items", 0], transaction) + 'Apple' + >>> get_in(["name"], transaction) + 'Alice' + >>> get_in(["purchase", "total"], transaction) + >>> get_in(["purchase", "items", "apple"], transaction) + >>> get_in(["purchase", "items", 10], transaction) + >>> get_in(["purchase", "total"], transaction, 0) + 0 + >>> get_in(["y"], {}, no_default=True) + Traceback (most recent call last): + ... + KeyError: 'y' + + See Also: + itertoolz.get + operator.getitem + """ + try: + return reduce(operator.getitem, keys, coll) + except (KeyError, IndexError, TypeError): + if no_default: + raise + return default + + +def getter(index): + if isinstance(index, list): + if len(index) == 1: + index = index[0] + return lambda x: (x[index],) + elif index: + return operator.itemgetter(*index) + else: + return lambda x: () + else: + return operator.itemgetter(index) + + +def groupby(key, seq): + """Group a collection by a key function + + >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"] + >>> groupby(len, names) # doctest: +SKIP + {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} + + >>> iseven = lambda x: x % 2 == 0 + >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP + {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} + + Non-callable keys imply grouping on a member. + + >>> groupby( + ... "gender", + ... [ + ... {"name": "Alice", "gender": "F"}, + ... {"name": "Bob", "gender": "M"}, + ... {"name": "Charlie", "gender": "M"}, + ... ], + ... ) # doctest:+SKIP + {'F': [{'gender': 'F', 'name': 'Alice'}], + 'M': [{'gender': 'M', 'name': 'Bob'}, + {'gender': 'M', 'name': 'Charlie'}]} + + Not to be confused with ``itertools.groupby`` + + See Also: + countby + """ + if not callable(key): + key = getter(key) + d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated] + for item in seq: + d[key(item)](item) + rv = {} + for k, v in d.items(): + rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined] + return rv + + +def first(seq): + """The first element in a sequence + + >>> first("ABC") + 'A' + """ + return next(iter(seq)) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/utils.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7634c9b2ec90b870143954f741c8eb3be01d8d6b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/utils.py @@ -0,0 +1,108 @@ +# mypy: allow-untyped-defs +__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"] + + +def hashable(x): + try: + hash(x) + return True + except TypeError: + return False + + +def transitive_get(key, d): + """Transitive dict.get + >>> d = {1: 2, 2: 3, 3: 4} + >>> d.get(1) + 2 + >>> transitive_get(1, d) + 4 + """ + while hashable(key) and key in d: + key = d[key] + return key + + +def raises(err, lamda): + try: + lamda() + return False + except err: + return True + + +# Taken from theano/theano/gof/sched.py +# Avoids licensing issues because this was written by Matthew Rocklin +def _toposort(edges): + """Topological sort algorithm by Kahn [1] - O(nodes + vertices) + inputs: + edges - a dict of the form {a: {b, c}} where b and c depend on a + outputs: + L - an ordered list of nodes that satisfy the dependencies of edges + >>> # xdoctest: +SKIP + >>> _toposort({1: (2, 3), 2: (3,)}) + [1, 2, 3] + Closely follows the wikipedia page [2] + [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", + Communications of the ACM + [2] http://en.wikipedia.org/wiki/Toposort#Algorithms + """ + incoming_edges = reverse_dict(edges) + incoming_edges = {k: set(val) for k, val in incoming_edges.items()} + S = {v for v in edges if v not in incoming_edges} + L = [] + + while S: + n = S.pop() + L.append(n) + for m in edges.get(n, ()): + assert n in incoming_edges[m] + incoming_edges[m].remove(n) + if not incoming_edges[m]: + S.add(m) + if any(incoming_edges.get(v, None) for v in edges): + raise ValueError("Input has cycles") + return L + + +def reverse_dict(d): + """Reverses direction of dependence dict + >>> d = {"a": (1, 2), "b": (2, 3), "c": ()} + >>> reverse_dict(d) # doctest: +SKIP + {1: ('a',), 2: ('a', 'b'), 3: ('b',)} + :note: dict order are not deterministic. As we iterate on the + input dict, it make the output of this function depend on the + dict order. So this function output order should be considered + as undeterministic. + """ + result = {} # type: ignore[var-annotated] + for key in d: + for val in d[key]: + result[val] = result.get(val, ()) + (key,) + return result + + +def xfail(func): + try: + func() + raise Exception("XFailed test passed") # pragma:nocover # noqa: TRY002 + except Exception: + pass + + +def freeze(d): + """Freeze container to hashable form + >>> freeze(1) + 1 + >>> freeze([1, 2]) + (1, 2) + >>> freeze({1: 2}) # doctest: +SKIP + frozenset([(1, 2)]) + """ + if isinstance(d, dict): + return frozenset(map(freeze, d.items())) + if isinstance(d, set): + return frozenset(map(freeze, d)) + if isinstance(d, (tuple, list)): + return tuple(map(freeze, d)) + return d diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/variable.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/variable.py new file mode 100644 index 0000000000000000000000000000000000000000..46e59851fdfa8389e29288176a50dc62fb568654 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unification/variable.py @@ -0,0 +1,90 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + +from .dispatch import dispatch +from .utils import hashable + + +_global_logic_variables = set() # type: ignore[var-annotated] +_glv = _global_logic_variables + + +class Var: + """Logic Variable""" + + _id = 1 + + def __new__(cls, *token): + if len(token) == 0: + token = f"_{Var._id}" # type: ignore[assignment] + Var._id += 1 + elif len(token) == 1: + token = token[0] + + obj = object.__new__(cls) + obj.token = token # type: ignore[attr-defined] + return obj + + def __str__(self): + return "~" + str(self.token) # type: ignore[attr-defined] + + __repr__ = __str__ + + def __eq__(self, other): + return type(self) == type(other) and self.token == other.token # type: ignore[attr-defined] + + def __hash__(self): + return hash((type(self), self.token)) # type: ignore[attr-defined] + + +def var(): + return lambda *args: Var(*args) + + +def vars(): + return lambda n: [var() for i in range(n)] + + +@dispatch(Var) +def isvar(v): + return True + + +isvar + + +@dispatch(object) # type: ignore[no-redef] +def isvar(o): + return not not _glv and hashable(o) and o in _glv + + +@contextmanager +def variables(*variables): + """ + Context manager for logic variables + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> from __future__ import with_statement + >>> with variables(1): + ... print(isvar(1)) + True + >>> print(isvar(1)) + False + >>> # Normal approach + >>> from unification import unify + >>> x = var("x") + >>> unify(x, 1) + {~x: 1} + >>> # Context Manager approach + >>> with variables("x"): + ... print(unify("x", 1)) + {'x': 1} + """ + old_global_logic_variables = _global_logic_variables.copy() + _global_logic_variables.update(set(variables)) + try: + yield + finally: + _global_logic_variables.clear() + _global_logic_variables.update(old_global_logic_variables) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/unify_refinements.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unify_refinements.py new file mode 100644 index 0000000000000000000000000000000000000000..bab662e0655a2c7c4049ff9b8ae50341567c1259 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/unify_refinements.py @@ -0,0 +1,124 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.graph_gradual_typechecker import Refine +from torch.fx.experimental.unification import unify, Var # type: ignore[attr-defined] +from torch.fx.tensor_type import TensorType + + +def infer_symbolic_types_single_pass(traced): + """ + Calls our symbolic inferencer once. + """ + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + +def infer_symbolic_types(traced): + """ + Calls our symbolic inferencer twice. + This is useful when one pass is not enough + to infer all the information such as the case + for braodcasting. + """ + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + r.symbolic_relations() + + +def convert_eq(list_of_eq): + """ + Convert equality constraints in the right format + to be used by unification library. + """ + lhs = [] + rhs = [] + for eq in list_of_eq: + lhs.append(eq.lhs) + rhs.append(eq.rhs) + return tuple(lhs), tuple(rhs) + + +def unify_eq(list_of_eq): + """ + Apply unification to a set of + equality constraints + """ + lhs, rhs = convert_eq(list_of_eq) + return unify(lhs, rhs) + + +def substitute_solution_one_type(mapping, t): + """ + Apply the most general unifier to a type + """ + if isinstance(t, Var): + if t in mapping.keys(): + return mapping[t] + else: + return t + + elif isinstance(t, TensorType): + new_type = [] + for typ in t.__args__: + if typ in mapping.keys(): + new_type.append(mapping[typ]) + else: + new_type.append(typ) + return TensorType(tuple(new_type)) + + elif isinstance(t, list): + new_type = [] + for typ in t: + new_type.append(substitute_solution_one_type(mapping, typ)) + return new_type + + elif isinstance(t, tuple): + new_type = [] + for typ in t: + new_type.append(substitute_solution_one_type(mapping, typ)) + return tuple(new_type) + + else: + return t + + +def substitute_all_types(graph, mapping): + """ + Apply the most general unifier to all types in a graph + till reaching a fixed point. If the input and output graph + are the same, we converge. + """ + flag = True + while flag: + flag = False + for k in mapping: + old_mapping_val = mapping[k] + if mapping[k] in mapping.keys(): + new_key = mapping[k] + mapping[k] = mapping[new_key] + if old_mapping_val != mapping[k]: + flag = True + + for n in graph.nodes: + n.type = substitute_solution_one_type(mapping, n.type) + + +def check_for_type_equality(g1, g2): + """ + A check equality to be used in fixed points. + We do not use graph equality but instead type + equality. + """ + for n, m in zip(g1.nodes, g2.nodes): + if n.type != m.type: + return False + return True diff --git a/.venv/lib/python3.12/site-packages/torch/fx/experimental/validator.py b/.venv/lib/python3.12/site-packages/torch/fx/experimental/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..9d70973225db820a0ee301f3d42c2791c6236936 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/experimental/validator.py @@ -0,0 +1,869 @@ +# mypy: allow-untyped-defs +import builtins +import functools +import logging +import math +import operator +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import sympy + +import torch +import torch.fx +import torch.fx.traceback as fx_traceback +from torch._dynamo.exc import TorchDynamoException +from torch._dynamo.utils import dynamo_timed +from torch.fx.node import Argument, Target +from torch.utils._sympy.interp import sympy_interp + + +log = logging.getLogger(__name__) + +try: + import z3 # type: ignore[import] + + # Translation Validation for Dynamo guards + # ======================================== + # + # Checks whether optimizations applied to the collected guards are + # valid. In other words, whether the guard function we actually run + # does not have false positives (unsound). + # + # In order to do so, we build the guards using 2 different information + # attached to each 'SymNode': + # 1. SymPy expressions + # 2. FX nodes + # + # SymPy expressions have implicit optimizations baked within itself, + # which may have a few bugs. On the other hand, we build the FX graph + # manually, with no optimizations enabled. This gives us access to + # the "ground truth". + # + # We then convert into Z3 expressions both the SymPy expressions + # (see [Note: SympyToZ3]) that reach 'ShapeEnv.produce_guards' function + # and the FX nodes (see [Note: PopulateValidator]) that go through + # 'ShapeEnv.evaluate_expr' function. Finally, we run the validation. + # (see [Note: TranslationValidator]) + # Better Z3 to string implementation (for a small fraction of Z3). + # + # Here are the things we clean before showing the Z3 expression: + # - Rename a few ops (e.g. "Distinct" ==> "!=") + # + # - Ignore ToInt and ToReal operations: + # usually they don't really matter + # + # - Transform (ToInt (/ ...)) into (idiv ...): + # this is the pattern for floor division + # + # - Collect a chain of the same operations into one + def z3str(e: z3.ExprRef) -> str: + assert z3.is_expr(e), f"unsupported expression type: {e}" + + def get_args_str(e: z3.ExprRef) -> list[str]: + return [z3str(e.arg(i)) for i in range(e.num_args())] + + # First, we simplify the given expression. + # This is done using rewriting rules, so shouldn't take long. + e = z3.simplify(e) + + # Only support function applications. + # Even Z3 "variables" are, in fact, function applications. + if not z3.is_app(e): + raise ValueError(f"can't print Z3 expression: {e}") + + if z3.is_int_value(e) or z3.is_rational_value(e): + return e.as_string() # type: ignore[attr-defined] + + decl = e.decl() + kind = decl.kind() + op = str(decl) + args = get_args_str(e) + + if kind == z3.Z3_OP_POWER: + op = "pow" + + elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL): + # Collect the arguments of chains of ADD and MUL. + # This is safe, since they are associative. + + def collect_str_args(e): + if not (z3.is_app(e) and e.decl().kind() == kind): + return [z3str(e)] + else: + return [ + x + for i in range(e.num_args()) + for x in collect_str_args(e.arg(i)) + ] + + args = collect_str_args(e) + + elif kind == z3.Z3_OP_NOT: + # Revert some conversions that z3.simplify applies: + # - a != b ==> (Not (== a b)) ==> (!= a b) + # - a < b ==> (Not (<= b a)) ==> (> b a) + # - a > b ==> (Not (<= a b)) ==> (> a b) + + assert e.num_args() == 1 + arg = e.arg(0) + + assert z3.is_app(arg) + argkind = arg.decl().kind() + + logic_inverse = { + z3.Z3_OP_EQ: "!=", + z3.Z3_OP_LE: ">", + z3.Z3_OP_GE: "<", + } + + if argkind in logic_inverse: + op = logic_inverse[argkind] + args = get_args_str(arg) + + elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL): + assert e.num_args() == 1 + argstr = z3str(e.arg(0)) + + # Check if it's the floor division pattern. + if argstr.startswith("(/"): + return "(idiv" + argstr[2:] + + # Otherwise, just ignore it. + return argstr + + elif kind == z3.Z3_OP_UNINTERPRETED: + assert e.num_args() == 0 + return str(decl) + + string = op + " " + " ".join(args) + return f"({string.rstrip()})" + + # We need to convert to/from BitVec in order to use z3 bitwise ops. + # We assume that integers are 64 bit. + # If all args are boolean, then use the boolean bitwise op implementation instead, if provided. + def _bitwise_op(bitwise_func, bool_func): + @functools.wraps(bitwise_func) + def wrapper(self, *args): + if bool_func is not None and all( + isinstance(arg, z3.BoolRef) for arg in args + ): + return bool_func(*args) + + wrapped_args = tuple(z3.Int2BV(a, 64) for a in args) + return z3.BV2Int(bitwise_func(*wrapped_args)) + + return wrapper + + # Implementation of Python semantics as Z3 expressions. + # + # Z3 Real-Int theory has operators with semantics that differ that of + # Python. Therefore, in order to get it right, we need to implement + # the (Python) semantics we are relying on in Z3. + @dataclass + class _Z3Ops: + # Validator used for adding assertions as needed. + # e.g. div(a, b) requires b != 0. + validator: "TranslationValidator" + + # The 2 functions below are used for conditionally casting between + # integer and reals. + # + # Returns a real expression from 'x'. + @staticmethod + def to_real(x: z3.ArithRef) -> z3.ArithRef: + return x if x.is_real() else z3.ToReal(x) + + # Returns an integer expression from 'x'. + @staticmethod + def to_int(x: z3.ArithRef) -> z3.ArithRef: + return x if x.is_int() else z3.ToInt(x) + + def sym_sum(self, args: z3.ArithRef) -> z3.ArithRef: + return sum(args) + + # Implements Python division semantics. + def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + self.validator.add_assertion(denominator != 0) # type: ignore[arg-type] + return _Z3Ops.to_real(numerator) / _Z3Ops.to_real(denominator) + + def floor(self, number: z3.ArithRef) -> z3.ArithRef: + # Z3 ToInt function rounds a real number towards negative infinity. + return _Z3Ops.to_int(number) + + # Python semantics for 'FloorDiv' states that before applying the floor + # function, the operands are converted to their common type. + def floordiv( + self, numerator: z3.ArithRef, denominator: z3.ArithRef + ) -> z3.ArithRef: + cast_result_to_real = numerator.is_real() or denominator.is_real() + result = _Z3Ops.to_int(self.div(numerator, denominator)) + # Since the 'result' is already an integer, we just have to check + # whether we should cast it to real. + return _Z3Ops.to_real(result) if cast_result_to_real else result + + def ceil(self, number: z3.ArithRef) -> z3.ArithRef: + return z3.If(self.floor(number) < number, self.floor(number + 1), number) # type: ignore[return-value] + + def trunc(self, number: z3.ArithRef) -> z3.ArithRef: + return z3.If(number >= 0, self.floor(number), self.ceil(number)) # type: ignore[return-value] + + def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: + return z3.If(a > b, a, b) # type: ignore[return-value] + + def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: + return z3.If(a < b, a, b) # type: ignore[return-value] + + # Python semantics for 'Mod' is defined as: p % q = p - floordiv(p, q) * q + # It should work with both integer and reals. + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: + return p - self.floordiv(p, q) * q + + def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + # Z3 can't handle complex numbers very well. + self.validator.add_assertion(z3.Or(base != 0, exp > 0)) # type: ignore[arg-type] + return base**exp + + def sqrt(self, number: z3.ArithRef) -> z3.ArithRef: + # Square-root: + # 1. Only work with reals + number = _Z3Ops.to_real(number) + # 2. The number should be positive or zero. + # Otherwise, Z3 returns 'unknown'. + self.validator.add_assertion(number >= 0) + return number**0.5 + + def abs(self, number: z3.ArithRef) -> z3.ArithRef: + return z3.Abs(number) + + def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef: + # Pythons builtin 'round' implements the 'round half to even' strategy + # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even + # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to + # floating point numbers, which is different from real numbers that we are dealing with here. + # Instead, we implement 'round half to even' in terms of 'round half up' (floor(x + 0.5)) and + # 'round half down' (ceil(x - 0.5)). + # Assuming 'round half up' is the default case, we need to correct ..., -3.5, -1.5, 0.5, 2.5, 4.5, ... + # to round down, i.e. use the 'round half down' strategy + return z3.If( + self.mod(number, z3.IntVal(2)) == 0.5, + self.ceil(number - 0.5), + self.floor(number + 0.5), + ) + + bitwise_and = _bitwise_op(operator.and_, z3.And) + bitwise_or = _bitwise_op(operator.or_, z3.Or) + lshift = _bitwise_op(operator.lshift, None) + rshift = _bitwise_op(operator.rshift, None) + + # Lifts a callable to be used in Z3. + # + # This function replaces the given 'op' by a function that: + # + # 1. Lifts the arguments into Z3 (i.e. make them inhabitants of Z3) + # + # 2. Calls an operation that corresponds to 'op', but works with Z3 + # inhabitants (left as is if it works as is) + def z3op(op: Callable, validator: "TranslationValidator") -> Callable: + # Operations that have booleans as their argument. + # This is needed because the argument of some FX nodes were + # literal integers, instead of booleans. So, whenever this flag + # is set, we also convert ints to booleans. + boolean_ops = {operator.not_} + as_bool = op in boolean_ops + + # Lifts the function into 'z3.ExprRef' domain. + def lift(func): + def wrap(a) -> z3.ExprRef: + if isinstance(a, (z3.ArithRef, z3.BoolRef)): + return a + # Convert it into a Z3 value, if it is some of the supported + # types below. + if isinstance(a, bool) or (as_bool and isinstance(a, int)): + return z3.BoolVal(bool(a)) + if isinstance(a, (int, sympy.Integer)): + return z3.IntVal(int(a)) + if isinstance(a, (float, sympy.Float)): + return z3.RealVal(float(a)) + raise ValueError(f"can't lift type: {type(a)}") + + @functools.wraps(func) + def wrapper(*args): + # Lifts the arguments into a list of Z3 inhabitants. + if len(args) == 1 and isinstance(args[0], (list, tuple)): + wrapped_args = (tuple(wrap(a) for a in args[0]),) + else: + wrapped_args = tuple(wrap(a) for a in args) + # Run the function on the Z3 expressions. + return func(*wrapped_args) + + return wrapper + + ops = _Z3Ops(validator) + replacement_map = { + # Operator module. + operator.not_: lift(z3.Not), + operator.and_: lift(ops.bitwise_and), + operator.or_: lift(ops.bitwise_or), + operator.lshift: lift(ops.lshift), + operator.rshift: lift(ops.rshift), + operator.floordiv: lift(ops.floordiv), + operator.truediv: lift(ops.div), + operator.mod: lift(ops.mod), + operator.abs: lift(ops.abs), + builtins.round: lift(ops.round_to_int), + # Math module. + math.ceil: lift(ops.ceil), + math.floor: lift(ops.floor), + math.trunc: lift(ops.trunc), + # Torch module. + torch.sym_float: lift(ops.to_real), + torch.sym_max: lift(ops.max), + torch.sym_min: lift(ops.min), + torch.sym_sum: lift(ops.sym_sum), + torch.sym_ite: lift(lambda b, t, f: t if b else f), + torch._sym_sqrt: lift(ops.sqrt), # type: ignore[attr-defined] + # Not lifted because we only use this function as a + # marker for adding the expression as validator input. + torch._assert: torch._assert, + } + return replacement_map[op] if op in replacement_map else lift(op) + + # Processes an FX graph, populating the given validator. + # + # [Note: PopulateValidator] + # This class walks through each node in the FX graph, translating + # them into the Z3 world. + # + # Then, whenever it finds an 'torch._assert' call_function operation, + # it adds the Z3 expression corresponding to the argument as validator + # input. + class PopulateValidator(torch.fx.Interpreter): + def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"): + # Reference to the translation validator. + self.validator = validator + + # Build the graph module and call `Interpreter` constructor. + module = torch.fx.GraphModule(root={}, graph=graph) + super().__init__(module, garbage_collect_values=True) + + def placeholder( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + symbol = fx_traceback.get_current_meta()["symbol"] + return self.validator.z3var(symbol) + + def call_function( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + if target != torch._assert: + # Lift and runs the node target function + return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type] + # Adds the Z3 expression corresponding to the first argument + # as a validator input. + assert len(args) == 1, ( + f"expected 1 argument on assertion. Got: {len(args)} " + ) + self.validator.add_source_expr(args[0]) # type: ignore[arg-type] + + # Translates SymPy expressions into Z3 expressions. + # + # [Note: SympyToZ3] + # At the time of the translation, all free variables present in the + # SymPy expression being translated must be already mapped to a Z3 + # integer variable. + class SympyToZ3: + OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"} + + def __init__( + self, + validator: "TranslationValidator", + ) -> None: + self._validator = validator + self._ops = _Z3Ops(self._validator) + + def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: + # TODO: Probably OK to relax this and allow lower precision + if dtype is torch.int64: + return z3.IntVal(int(value)) + if dtype is torch.double: + return z3.RealVal(float(value)) + if dtype is torch.bool: + return z3.BoolVal(bool(value)) + raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") + + def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + if dtype == torch.float64: + return z3.ToReal(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + + def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return z3.ToInt(x) + + def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.round_to_int(x) + + def int_truediv( + self, numerator: z3.ArithRef, denominator: z3.ArithRef + ) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + + def truediv( + self, numerator: z3.ArithRef, denominator: z3.ArithRef + ) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + + def floordiv( + self, numerator: z3.ArithRef, denominator: z3.ArithRef + ) -> z3.ArithRef: + return self._ops.floordiv(numerator, denominator) + + def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.floordiv(numerator, denominator) + + def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + + def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: + return self._ops.mod(p, q) + + def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.ceil(x) + + def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.floor(x) + + def __getattr__(self, name: str) -> Any: + REPLACEMENT = { + "and_": z3.And, + "or_": z3.Or, + "not_": z3.Not, + "bitwise_and": self._ops.bitwise_and, + "bitwise_or": self._ops.bitwise_or, + "lshift": self._ops.lshift, + "rshift": self._ops.rshift, + "floor": self._ops.floor, + "ceil": self._ops.ceil, + "minimum": self._ops.min, + "maximum": self._ops.max, + } + + if name in REPLACEMENT: + return REPLACEMENT[name] + if name in self.OPERATOR_HANDLES: + return getattr(operator, name) + raise AttributeError(f"unhandled operator: {name}") + + def run(self, expr: sympy.Basic) -> z3.ExprRef: + return sympy_interp(self, self._validator.symbols, expr) # type: ignore[arg-type] + + # Dynamo guards translation validator. + # + # [Note: TranslationValidator] + # Verifies whether the guards issued by 'ShapeEnv.produce_guards' are sound. + # That is: whether those (target) guards only yield TRUE whenever the original, + # unoptimized, (source) guards yield TRUE. + # + # More concretely, given 'source' and 'target' guard expressions, we wish to + # check whether the following expression holds: + # + # Not(And(source)) AND And(target) + # + # i.e. whether there is an assignment of the free variables where the opposite + # happens: target is TRUE, but source is FALSE. + class TranslationValidator: + def __init__(self) -> None: + log.debug("new instance") + + # Mapping of SymPy symbols to Z3 variables. + self.symbols: dict[sympy.Symbol, z3.ExprRef] = {} + + # Set of source Z3 expressions. + # They represent the generated guards without any kind of + # simplification or transformation. + self._source_exprs: set[z3.BoolRef] = set() + + # Set of target Z3 expressions. + # They represent the actual checked guards at runtime. They might + # be simplified or transformed versions of the source guards. + self._target_exprs: set[z3.BoolRef] = set() + + # Set of Z3 expressions representing assertions over both the + # source and target expressions. + self._assertions: set[z3.BoolRef] = set() + + # Retrieves the corresponding Z3 variable. + def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef: + assert symbol in self.symbols, f"Z3 variable not found for: {symbol}" + return self.symbols[symbol] + + # Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exists. + def add_var(self, symbol: sympy.Symbol, type: type) -> z3.ExprRef: + if symbol in self.symbols: + return self.symbols[symbol] + + log.debug("new variable: %s (%s)", symbol.name, type.__name__) + + if type is int: + var = z3.Int(symbol.name) + + # If 'symbol' is positive (SymPy assumption), we have to + # convey it to Z3 as well. + if symbol.is_positive: # type: ignore[attr-defined] + self._target_exprs.add(var > 0) + elif type is float: + var = z3.Real(symbol.name) + elif type is bool: + var = z3.Bool(symbol.name) + else: + raise RuntimeError(f"unsupported type for Z3 variable: {type}") + + self.symbols[symbol] = var + return var + + # Checks whether all symbols were already added. + def _check_freesymbols(self, e: sympy.Basic) -> None: + for s in e.free_symbols: + assert isinstance(s, sympy.Symbol) + # Call 'z3var' just to check whether there's already a + # Z3 variable corresponding to 's'. + self.z3var(s) + + def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef: + z3expr = SympyToZ3(self).run(e) + assert isinstance(z3expr, z3.BoolRef), ( + f"expected boolean expression. Got: {z3expr}" + ) + return z3expr + + def add_source_expr(self, e: z3.BoolRef) -> None: + if e not in self._source_exprs: + log.debug("add source guard: %s", z3str(e)) + self._source_exprs.add(e) + + def add_target_expr(self, e: "sympy.logic.boolalg.Boolean") -> None: + self._check_freesymbols(e) + z3expr = self.to_z3_boolean_expr(e) + if e not in self._target_exprs: + log.debug("add target guard: %s", z3str(z3expr)) + self._target_exprs.add(z3expr) + + def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None: + if isinstance(e, sympy.Basic): + self._check_freesymbols(e) + ref = self.to_z3_boolean_expr(e) + else: + ref = e + assert isinstance(ref, z3.BoolRef) + if ref not in self._assertions: + log.debug("add assertion: %s", z3str(ref)) + self._assertions.add(ref) + + def validate(self) -> None: + with dynamo_timed("TranslationValidator.validate"): + return self._validate() + + def _validate(self) -> None: + if len(self._source_exprs) == 0 or len(self._target_exprs) == 0: + # If there are no source/target expressions, there's nothing we really + # wish to prove. So, we just return. + return None + + # Here, we use "QF_NRA" logic for the solver: + # "Quantifier-free Non-linear Real Arithmetic". + # + # Most of the guards expressions have: + # 1. arithmetic between integer and reals + # 2. no quantifiers + # 3. potentially non-linear. + # + # Although there's also "QF_NIRA" (mixed integer-real arithmetic), + # "QF_NRA" seems to work better on 'dynamo/test_dynamic_shapes.py'. + solver = z3.SolverFor("QF_NRA") + # Set a timeout for finding a solution. + solver.set(timeout=translation_validation_timeout()) + + # Add all the assertions to the solver. + for assertion in self._assertions: + solver.add(assertion) + + # "Is there any case where it's TRUE for the target expressions, + # but FALSE for the source expressions?" + solver.add(z3.Not(z3.And(*self._source_exprs))) + solver.add(*self._target_exprs) + + log.debug("translation validation: start") + r = solver.check() + if r == z3.sat: + # Target expressions are unsound. + # Log the found model and the source expressions that failed. + model = solver.model() + raise ValidationException( + model, + self._assertions, + self._target_exprs, + failed_source_exprs=[ + inp for inp in self._source_exprs if not model.evaluate(inp) + ], + ) + else: + if r == z3.unknown: + # Could not find a solution. It didn't fail, but it also + # didn't succeed. Canceling the validation execution (keyboard + # interrupt) also gets to this branch. + log.warning( + "translation validation: could not validate: got z3.unknown" + ) + else: + # Target expressions are sound. + assert r == z3.unsat + log.debug("translation validation: success") + +except ImportError: + _HAS_Z3 = False + + __all__ = [ + "translation_validation_enabled", + "translation_validation_timeout", + "ValidationException", + "BisectValidationException", + ] + +else: + _HAS_Z3 = True + + __all__ = [ + "z3str", + "z3op", + "PopulateValidator", + "SympyToZ3", + "TranslationValidator", + "translation_validation_enabled", + "translation_validation_timeout", + "ValidationException", + "BisectValidationException", + ] + +from torch.fx.experimental import _config as config + + +def translation_validation_enabled() -> bool: + # Checks everytime this function is called, in case the Dynamo + # option is set, but Z3 is not installed. + _assert_z3_installed_if_tv_set() + return _HAS_Z3 and config.translation_validation + + +def translation_validation_timeout() -> int: + return config.translation_validation_timeout + + +def _assert_z3_installed_if_tv_set(): + assert _HAS_Z3 or not config.translation_validation, ( + "translation validation requires Z3 package. Please, either install " + "z3-solver or disable translation validation." + ) + + +class ValidationException(TorchDynamoException): + def __init__(self, model, assertions, target_exprs, failed_source_exprs): + assert _HAS_Z3 + + def symbolstr(sym) -> str: + return f"{sym}: {model[sym]}" + + def joinlines(xs) -> str: + return "\n".join(f" ==> {x}" for x in xs) + + model_str = joinlines(sorted(map(symbolstr, model))) + assertions_str = joinlines(sorted(map(z3str, assertions))) + target_exprs_str = joinlines(sorted(map(z3str, target_exprs))) + failed_source_exprs_str = joinlines(sorted(map(z3str, failed_source_exprs))) + + self.msg = "translation validation failed." + self.details = f"""\ +Model: +{model_str} + +Assertions: +{assertions_str} + +Target Expressions: +{target_exprs_str} + +Failed Source Expressions: +{failed_source_exprs_str}""" + + def __str__(self): + return f"{self.msg}\n\n{self.details}" + + +class BisectValidationException(TorchDynamoException): + def __init__(self, validation_exc, expr, failed_action, traced_node): + self.msg = f"translation validation failed when {failed_action}: {expr}" + self.details = f"""\ +Failure occurred while running node: + {traced_node.format_node()} + +{validation_exc.details}""" + + def __str__(self): + return f"{self.msg}\n\n{self.details}" + + +# Checks when this module is loaded. +_assert_z3_installed_if_tv_set() + + +# Translation validation bisection. +# +# Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise +# the earliest ValidationException. +# +# As guards are added by ShapeEnv.evaluate_expr calls, some simplification errors +# might be silently happening. This function tries to nail down exactly at which +# point things went wrong from a validation perspective. +def bisect(shape_env): + from torch.fx.experimental.recording import ( + FakeTensorMeta, + replay_shape_env_events, + ShapeEnvEvent, + ) + from torch.fx.experimental.symbolic_shapes import ( + CURRENT_NODE_KEY, + ShapeEnv, + SHAPEENV_EVENT_KEY, + ) + + events = shape_env.events + + # Retrieves the ShapeEnvEvent associated with node. + def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent: + assert SHAPEENV_EVENT_KEY in node.meta + return events[node.meta[SHAPEENV_EVENT_KEY]] + + # Creates a new instance of fake, but updating every symbolic value's ShapeEnv + # reference to the one given as argument. + # + # This is needed so as not to simplify a symbolic expression using a ShapeEnv + # "from the future", where it may have a different set of replacements. + def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any: + if isinstance(fake, int): + return fake + if isinstance(fake, torch.SymInt): + return torch.SymInt(fake.node.with_shape_env(shape_env)) + if isinstance(fake, torch.SymFloat): + return torch.SymFloat(fake.node.with_shape_env(shape_env)) + assert isinstance(fake, FakeTensorMeta) + return FakeTensorMeta( + tuple(new_with_shape_env(shape_env, s) for s in fake.size()), + tuple(new_with_shape_env(shape_env, s) for s in fake.stride()), + new_with_shape_env(shape_env, fake.storage_offset()), + fake.is_nested, + ) + + # Checks whether the given shape_env fails when produce_guards is called. + def check_shapeenv_fails( + shape_env: ShapeEnv, tracked_fakes: Optional[list[Any]] + ) -> Optional[ValidationException]: + assert tracked_fakes is not None + try: + # This produce_guards call is a best-effort replication, since we + # don't populate EqualityConstraint list. Reason: we would also have + # to save OutputGraph.tracked_fakes_id_to_source. + shape_env.produce_guards( + [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes], + [a.source for a in tracked_fakes], + input_contexts=[a.symbolic_context for a in tracked_fakes], + ) + return None + except ValidationException as e: + return e + + # Checks whether the ShapeEnv reconstructed by replaying the events until + # node is created fails when produce_guards is called. + def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]: + number = node.meta[SHAPEENV_EVENT_KEY] + # Reconstruct shape_env until the event at event_number. + shape_env = replay_shape_env_events(events[: number + 1]) + shape_env.graph.lint() + return check_shapeenv_fails(shape_env, events[number].tracked_fakes) + + last_exception = check_shapeenv_fails( + shape_env, shape_env._snapshot_tracked_fakes() + ) + + if not last_exception: + # We don't actually fail due to a produce_guards call. + # Stop and don't bisect. + log.info("translation validation succeeded: no errors found.") + return + + if not shape_env.should_record_events or config.translation_validation_no_bisect: + # Bisection is off. + # Return the last ValidationException we got. + raise last_exception + + # Cache the raised exception (if any) at each bisection point. + exception = {} + + # Bisection happens on the assertion nodes of the recorded FX graph for + # dynamic shapes. + assert_nodes = [ + node for node in shape_env.graph.nodes if node.target == torch._assert + ] + + # Preparing the indices for binary search. + # The overall invariants are + # - for all i < left, assert_node[i] doesn't fail + # - for all i >= right, assert_node[i] fails + # - `right in exception` always holds + # - `left <= right` always holds + left, mid, right = 0, 0, len(assert_nodes) - 1 + exception[right] = check_node_fails(assert_nodes[right]) + + while left < right: + mid = (left + right) // 2 + + node = assert_nodes[mid] + log.debug("bisecting at %s: %s", mid, get_node_event(node)) + + # Check whether the new shape_env raises a ValidationException or not. + exception[mid] = check_node_fails(node) + + if exception[mid]: + right = mid + else: + left = mid + 1 + + assert left in exception and isinstance(exception[left], ValidationException) + + node = assert_nodes[left] + event = get_node_event(node) + + if event.is_evaluate_expr(): + failed_action = "evaluating" + else: + assert event.is_defer_runtime_assert(), f"unexpected event type: {event}" + failed_action = "adding runtime assert" + + args = event.args + assert args is not None + assert len(args) >= 2, ( + f"bisecting expects {event.name} to have at least 2 positional arguments. " + f"Got: {len(args)}" + ) + assert isinstance(args[1], sympy.Basic), ( + f"bisecting expects {event.name} to have a SymPy expression as its second argument. " + f"Got: {type(args[1])}" + ) + + raise BisectValidationException( + exception[left], + expr=args[1], + failed_action=failed_action, + traced_node=node.meta[CURRENT_NODE_KEY], + ) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__init__.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..433d8818e259a6c0d8d674d15a0312815010ec7f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/__init__.py @@ -0,0 +1,14 @@ +from . import ( + graph_drawer, + graph_manipulation, + net_min_base, + operator_support, + param_fetch, + reinplace, + runtime_assert, + shape_prop, + split_module, + split_utils, + splitter_base, + tools_common, +) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4a000f8a99e5871a71953715eb10542b07c3a36 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4978e1f8cd7244049997413a1615f4a7e2634b18 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55817dc40c272fea5ab7e7da2217502679b36e70 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f458bbde129dddedf6a622ca1566cc52173b6783 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f546d6753087f89d86cbadd5c76530c4bb123c9 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..888a1a31bac090a3b273ddd4f48c07483109febb Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b82ca8a212412284693ff083d2d9217ae4f74a35 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95e631be646419a1b51585c6a2164d05e0b7256e Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8208b593d60dd7c6cff1139c7b94026abc2b2a8a Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..393ed3f7ed719766d3dcf73f778db9f7dc06cef4 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a00d6d03549128a80b0e4301f98124eca4ba857f Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_module.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_module.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06a088f965e3147bff34531307ccfc79fb323eea Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_module.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c160635a8a857842678cb05c59637b70a769d29 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cf330091aa27b7af42a7e37e51f153b6f2b37a9 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb454f2aeb114f29f21b993306db89b3a463986a Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/_tensorify_python_scalars.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/_tensorify_python_scalars.py new file mode 100644 index 0000000000000000000000000000000000000000..a3ae39396619aa1d010e7e4a5bb4e9eea4dcd53d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/_tensorify_python_scalars.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +import logging +import os +from typing import Any, Union + +from sympy import Integer, Number, Symbol +from sympy.logic.boolalg import BooleanAtom + +import torch +import torch.fx as fx +from torch._dynamo.exc import TensorifyScalarRestartAnalysis +from torch._dynamo.symbolic_convert import TensorifyState +from torch._dynamo.utils import get_metrics_context +from torch._prims_common import get_computation_dtype +from torch._subclasses import fake_tensor # noqa: TCH001 +from torch._subclasses.fake_tensor import FakeTensor +from torch._utils_internal import justknobs_check +from torch.fx._utils import lazy_format_graph_code +from torch.fx.experimental.symbolic_shapes import ( # noqa: TCH001 + guard_scalar, + has_free_symbols, + ShapeEnv, +) +from torch.fx.graph_module import GraphModule # noqa: TCH001 + +# TODO: refactor +from torch.fx.passes.runtime_assert import _get_sym_val +from torch.fx.proxy import MetaProxy +from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp +from torch.utils._sympy.reference import TensorReferenceAnalysis +from torch.utils._sympy.symbol import symbol_is_type, SymT + + +__all__: list[str] = [] + +log = logging.getLogger(__name__) +graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code_verbose") + +# The general shape of this transformation is to look for Tensor operations +# that take a backed SymFloat as an argument, and then redo them as tensor +# compute (with ints and tensors as inputs). For example, add(Tensor, Scalar) +# can be translated into add(Tensor, Tensor). Because Dynamo has already +# arranged for floats to be Tensor inputs to the graph, for typical float +# compute you can entirely translate the Python float operations into Tensor +# operations with only Tensor inputs. +# +# This pass is also responsible for doing CSE on the fly as we do this, since +# you don't want to keep recomputing the same quantity over and over again if +# it's used multiple times. +# +# This pass runs on the JOINT graph produced by AOT Autograd, prior to partitioning. +# The primary goal of this pass is to eliminate floats by replacing TensorScalar +# operations with TensorTensor operations and then Dead Code Elimination (DCE) of +# the item calls, which effectively removes the floats. +# +# This needs to happen before partitioning because it influences partitioning decisions, +# specifically by ensuring that we don't need to save floats across partitions. +# Additionally, there is a separate pass that changes which device computations +# occur on. That pass must be run after this one, but still before partitioning. +# +# HISTORY NOTE: Originally, I wanted to formulate this pass as pushing item() +# calls down, transforming float compute into int compute as we went. If you +# manage to eliminate all float compute, this ends up being equivalent, but +# there is a critical difference when some floats cannot be eliminated: when +# we call item() on them, what should it's SymFloat be? Ideally, it would +# be the same backed SymFloat we had before. But without symbolic expresssion +# propogation on tensor quantities, repropagating would instead give you an +# unbacked SymFloat. Maybe it is a good idea to implement symbolic propagation +# on 0d scalar tensors, but I decided to go for something simpler to start. +# +# The boring stuff: +# +# * What operators can I Tensor-ify? (Anything with a Scalar argument) +# * How do I Tensor-ify a SymFloat sympy expression (Sympy -> Op Handler -> Tensor) +# +# TODO: make sure this runs before CPU->CUDA pass for cudagraph friendliness + + +SUPPORTED_OPS = { + torch.ops.aten.mul.Tensor: torch.ops.aten.mul.Tensor, + torch.ops.aten.add.Tensor: torch.ops.aten.add.Tensor, + torch.ops.aten.sub.Tensor: torch.ops.aten.sub.Tensor, + torch.ops.aten.div.Tensor: torch.ops.aten.div.Tensor, + torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor, + torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor, + torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor, + torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor, + torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor, + torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor, +} + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def tensorify_python_scalars( + gm: GraphModule, shape_env: ShapeEnv, fake_mode: fake_tensor.FakeTensorMode +) -> None: + """ + Converts Python scalar operations into Tensor operations within the graph. This pass looks for + Tensor operations that involve SymFloat arguments and transforms them into equivalent operations + that use only Tensor inputs. + + Args: + gm: The FX graph module representing the computation graph. + shape_env: The shape environment responsible for symbolic shape tracking and propagation + during graph transformations. + + Returns: + None + """ + import sympy + + knob = True + if (env := os.getenv("TENSORIFY_PYTHON_SCALARS")) is not None: + if env in ("0", "FALSE"): + knob = False + else: + knob = justknobs_check("pytorch/compiler:tensorify_python_scalars") + if not knob: + return None + + graph = gm.graph + tracer = fx.proxy.GraphAppendingTracer(graph) + expr_to_sym_proxy: dict[sympy.Expr, MetaProxy] = {} + expr_to_tensor_proxy: dict[sympy.Expr, MetaProxy] = {} + tensorified_symbols: set[sympy.Symbol] = set() + should_restart = False + + first_non_placeholder = None + placeholders = set() + for node in graph.nodes: + if node.op != "placeholder": + first_non_placeholder = node + break + else: + placeholders.add(node) + + Analysis = TensorReferenceAnalysis + + def _sympy_interp(expr: sympy.Expr) -> MetaProxy: + # sympy_interp() with hash consing, and special handling for + # generating constants correctly + + # hash cons + if isinstance(expr, Symbol) and expr not in expr_to_tensor_proxy: + # This is guaranteed to be populated by invariant established by + # insert_deferred_runtime_asserts + expr_to_tensor_proxy[expr] = torch.ops.aten.scalar_tensor.default( + expr_to_sym_proxy[expr] + ) + + # cache constants, why not + if isinstance(expr, (Integer, Number, BooleanAtom)): + dtype = None + c: Union[bool, int, float] + if isinstance(expr, BooleanAtom): + dtype = torch.bool + c = bool(expr) + elif isinstance(expr, sympy.Integer): + dtype = torch.int64 + c = int(expr) + elif isinstance(expr, sympy.Number): + dtype = torch.float64 + c = float(expr) + + node = graph.call_function( + torch.ops.aten.scalar_tensor.default, (c,), {"dtype": dtype} + ) + with fake_mode: + node.meta["val"] = torch.ops.aten.scalar_tensor.default(c, dtype=dtype) + expr_to_tensor_proxy[expr] = MetaProxy( + node, + tracer=tracer, + fake_mode=fake_mode, + ) + + if expr in expr_to_tensor_proxy: + return expr_to_tensor_proxy[expr] + + # don't cache + if isinstance(expr, Symbol): + return sympy_interp(Analysis, expr_to_tensor_proxy, expr) # type: ignore[arg-type] + + # hash cons on arguments, run expr handler + expr_to_tensor_proxy[expr] = _run_sympy_handler( + Analysis, + [_sympy_interp(arg) for arg in expr.args], # type: ignore[arg-type] + expr, + ) + + return expr_to_tensor_proxy[expr] + + failed_tensorify_ops: set[str] = set() + nodes = list(graph.nodes) + for i, node in enumerate(nodes[:-1]): + with graph.inserting_before( + nodes[i + 1] if node not in placeholders else first_non_placeholder + ): + # Look for tensor.item() calls on placeholders + if ( + node is not None + and node.op == "call_function" + and node.target is torch.ops.aten._local_scalar_dense.default + ): + dtype = node.args[0].meta["val"].dtype + if dtype != torch.float64: + continue + + assert isinstance(node.args[0], fx.Node), node.args[0] + + s = node.meta["val"].node.expr + expr_to_tensor_proxy[s] = MetaProxy( + node.args[0], tracer=tracer, fake_mode=fake_mode + ) + expr_to_sym_proxy[s] = MetaProxy( + node, tracer=tracer, fake_mode=fake_mode + ) + elif (sym_expr := _get_sym_val(node)) is not None: + if sym_expr not in expr_to_sym_proxy and not isinstance( + sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom) + ): + expr_to_sym_proxy[sym_expr] = MetaProxy( + node, tracer=tracer, fake_mode=fake_mode + ) + + # Specialize all dimensions that contain symfloats. Here's + # an example test that requires this: + # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 # noqa: B950 + val = node.meta.get("val") + if isinstance(val, FakeTensor): + for dim in val.shape: + if isinstance(dim, torch.SymInt): + for s in dim.node.expr.free_symbols: + name = str(s) + if symbol_is_type( + s, SymT.FLOAT + ) and not TensorifyState.should_specialize(name): + # In principle, we could support float input that + # is used to do size compute. The problem is that + # we don't actually want to tensorify the compute + # in this case, which means we need codegen support for + # all symfloats. + TensorifyState.specialize(name) + should_restart = True + + # Look for functions to convert + if node.op == "call_function" and ( + replacement_op := SUPPORTED_OPS.get(node.target) + ): + args: list[Any] = [] + transform = False + compute_dtype = get_computation_dtype(node.meta["val"].dtype) + + for a in node.args: + if ( + isinstance(a, fx.Node) + and "val" in a.meta + and isinstance(zf := a.meta["val"], torch.SymFloat) + ): + transform = True + try: + proxy = _sympy_interp(zf.node.expr) + except NotImplementedError: + transform = False + break + + # We use _expr instead of expr b/c we want the symbol not the replacement + tensorified_symbols.add(a.meta["val"].node._expr) + + # The upcasting is irrelevant when the compute dtype is bool. This happens + # in cases where we are tensorifying a comparison operator such as + # torch.ops.aten.gt.Tensor + if ( + compute_dtype != torch.bool + and proxy.node.meta["val"].dtype != compute_dtype + ): + proxy = torch.ops.prims.convert_element_type.default( + proxy, compute_dtype + ) + + args.append(proxy) + elif isinstance(a, fx.Node): + args.append(MetaProxy(a, tracer=tracer, fake_mode=fake_mode)) + else: + args.append(a) + + if transform: + replacement_proxy = replacement_op(*args) + + if compute_dtype != node.meta["val"].dtype: + replacement_proxy = ( + torch.ops.prims.convert_element_type.default( + replacement_proxy, + node.meta["val"].dtype, + ) + ) + + node.replace_all_uses_with(replacement_proxy.node) + graph.erase_node(node) + + metrics_context = get_metrics_context() + if metrics_context.in_progress(): + metrics_context.set( + "tensorify_float_success", True, overwrite=True + ) + else: + for a in node.args: + if ( + isinstance(a, fx.Node) + and "val" in a.meta + and isinstance(zf := a.meta["val"], torch.SymFloat) + ): + failed_tensorify_ops.update(str(node.target)) + log.info("Failed to tensorify %s", str(node.target)) + + # Now do one more pass that specializes all symfloats we didn't manage + # to tensorify away. + for node in reversed(graph.nodes): + if node.op == "output" or node.op == "placeholder": + continue + + with graph.inserting_before(node): + if len(node.users) == 0 and not node.is_impure(): + graph.erase_node(node) + continue + + if isinstance( + (val := node.meta.get("val")), + (torch.SymFloat, torch.SymInt, torch.SymBool), + ): + if has_free_symbols(val.node.expr) and all( + symbol_is_type(s, SymT.FLOAT) for s in val.node.expr.free_symbols + ): + # If all symbols are backed symfloats, we can just specialize the whole node + # and get more precise guards. eg. + # + # zf = a.item() + # zf2 = zf // 2 + # op(.. zf2 ..) + # + # It's better to guard on zf // 2 == 2.0 than zf == 5.0 + + node.replace_all_uses_with(guard_scalar(val)) + graph.erase_node(node) + + # Sometimes by the time we get to tensorify, there have already been + # specializations, eg. in python_arg_parser.h. In these cases, + # placeholder nodes no longer have a reference to their original + # symfloat and thus we need to deduce specializations have happend + # via shape_env.replacements. NB: there's an important invariant here + # that symfloats keep consistent names across restarts. + for k, v in shape_env.var_to_val.items(): + if symbol_is_type(k, SymT.FLOAT) and isinstance(v, sympy.core.numbers.Float): + name = str(k) + if ( + not TensorifyState.should_specialize(name) + and k not in tensorified_symbols + ): + TensorifyState.specialize(name) + should_restart = True + + if should_restart: + # Sledgehammer time. Restart dynamo analysis, keeping track of which input sources + # are no longer needed and should be specialized. Restarting analysis is necessary + # because we need to instruct Dynamo to NOT make these as inputs. + metrics_context = get_metrics_context() + if metrics_context.in_progress(): + metrics_context.set( + "tensorify_float_failure", failed_tensorify_ops, overwrite=True + ) + metrics_context.set("tensorify_float_success", True, overwrite=True) + raise TensorifyScalarRestartAnalysis + + graph_code_log.debug( + "%s", lazy_format_graph_code("tensorify_python_scalars", gm, colored=True) + ) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/annotate_getitem_nodes.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/annotate_getitem_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..0a31a76420b34814a6148fa1ffa21f9b0dc897fb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/annotate_getitem_nodes.py @@ -0,0 +1,59 @@ +import operator + +import torch + + +def annotate_getitem_nodes(graph: torch.fx.Graph) -> None: + """ + Annotate the type of getitem nodes, inferred from the type of sequence node. + If sequence node is not annotated with a type, do nothing. + Currently support getitem nodes from tuple, list, and NamedTuple sequence node. + + This is helpful since annotations on local names within function are lost during FX transforms. + Adding back known type annotation for getitem nodes to improve jit scriptability. + + Args: + graph (Graph): The graph to be annotated + """ + for node in graph.nodes: + if node.target == operator.getitem: + sequence_node, index_node = node.args + if not sequence_node.type: + continue + # container types + if hasattr(sequence_node.type, "_name"): + parameterized_types = sequence_node.type.__args__ + if sequence_node.type._name == "Tuple": + if len(parameterized_types) == 2 and isinstance( + parameterized_types[1], type(...) + ): + node.type = parameterized_types[0] + else: + assert len(parameterized_types) > index_node + node_type = parameterized_types[index_node] + node.type = node_type + elif sequence_node.type._name == "List": + assert len(parameterized_types) == 1 + node.type = parameterized_types[0] + # Generic Alias Type + elif hasattr(sequence_node.type, "__origin__"): + parameterized_types = sequence_node.type.__args__ + if sequence_node.type.__origin__ is tuple: + if len(parameterized_types) == 2 and isinstance( + parameterized_types[1], type(...) + ): + node.type = parameterized_types[0] + else: + assert len(parameterized_types) > index_node + node_type = parameterized_types[index_node] + node.type = node_type + elif sequence_node.type.__origin__ is list: + assert len(parameterized_types) == 1 + node.type = parameterized_types[0] + # NamedTuple type + elif hasattr(sequence_node.type, "__annotations__"): + if sequence_node.type == torch.Tensor: + continue + sequence_node_field_types = sequence_node.type.__annotations__ + field_name = sequence_node.type._fields[index_node] + node.type = sequence_node_field_types[field_name] diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/__init__.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/cudagraphs.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/cudagraphs.py new file mode 100644 index 0000000000000000000000000000000000000000..b98178f0d5339321673deffdac6f03a96ffbde45 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/backends/cudagraphs.py @@ -0,0 +1,61 @@ +# mypy: allow-untyped-defs +import operator + +import torch +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupport +from torch.fx.passes.tools_common import CALLABLE_NODE_OPS +from torch.utils import _pytree as pytree + + +class CudaGraphsSupport(OperatorSupport): + # TODO: why is submodules passed here + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + if node.op not in CALLABLE_NODE_OPS: + return False + + if node.target in [torch.ops.aten.embedding_dense_backward.default]: + return False + + if node.target in [operator.getitem]: + return True + + found_not_cuda = False + + def meta_fk(meta): + return meta["val"] if "val" in meta else meta["fake_result"] + + def find_not_cuda(t): + nonlocal found_not_cuda + if isinstance(t, torch.Tensor) and t.device.type != "cuda": + found_not_cuda = True + + for n in node.all_input_nodes: + pytree.tree_map_(find_not_cuda, meta_fk(n.meta)) + + pytree.tree_map_(find_not_cuda, meta_fk(node.meta)) + + # NB: factory function is accounted for because the result would be + # cpu or cuda + + return not found_not_cuda + + +def partition_cudagraphs(gm, inputs): + """ + Partition an FX graph into sub-GraphModules that can be validly run under + CUDA graphs. For a subgraph to be runnable under CUDA, all of the operations + must involve CUDA tensors only/ + """ + + FakeTensorProp(gm).propagate(*inputs) + supported_ops = CudaGraphsSupport() + # TODO: single node partition may be wrong due to the pessimization + # from copying in and out the data. Check in benchmarks, perhaps + partitioner = CapabilityBasedPartitioner( + gm, supported_ops, allows_single_node_partition=True + ) + partitions = partitioner.propose_partitions() + fused_graph = partitioner.fuse_partitions(partitions) + return fused_graph diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/__init__.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__init__.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/cse_pass.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/cse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..e5889375bb07ae0f56917aff9950db67ff3f4bec --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/dialect/common/cse_pass.py @@ -0,0 +1,155 @@ +# mypy: allow-untyped-defs +from typing import Any + +import torch +from torch.fx import Graph, GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.utils._pytree import tree_flatten + + +aten = torch.ops.aten + + +# stateful ops are banned from CSE +rand_ops = { + aten.dropout, + aten._fused_dropout, + aten._standard_gamma, + aten.bernoulli, + aten.multinomial, + aten.native_dropout, + aten.normal, + aten.poisson, + aten.binomial, + aten.rrelu, + aten.rand_like, + aten.rand, + aten.randint, + aten.randn, + aten.randperm, +} # noqa: E501,B950 + +inplace_ops = { + aten.add_, + aten.sub_, + aten.mul_, + aten.div_, + aten.pow_, + aten.lerp_, + aten.relu_, + aten.sigmoid_, + aten.tanh_, +} # noqa: E501 + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def get_CSE_banned_ops(): + return rand_ops.union(inplace_ops) + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +class CSEPass(PassBase): + def __init__(self, banned_ops=None): + """ + This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node. + + For functional dialects, user would only need to specify the random ops in ban list. + + Warning: CSE Pass cannot be safely applied on a FX graph in non-functional dialects. + If your dialect contains stateful operators, please customized the banned_ops. + + """ + if banned_ops is None: + banned_ops = set() + self.banned_ops = banned_ops + super().__init__() + + def call(self, graph_module: GraphModule) -> PassResult: + """ + Return a new copy of torch.fx.GraphModule with CSE applied to the input graph + + Example usage: + + from torch.fx.experimental.proxy_tensor import make_fx + def f(a): + b = a * a + c = a * a + return b+c + + p = CSEPass() + traced_graph = make_fx(f)(torch.tensor(1)) + print(traced_graph) + result = p(traced_graph) + print(result.graph_module) + """ + + def get_aten_target(node): + if hasattr(node.target, "overloadpacket"): + return node.target.overloadpacket + return node.target + + modified = False + new_graph = Graph() + env: dict[ + Node, Node + ] = {} # map from node in the old graph to node in the new graph + hash_env: dict[ + tuple[torch._ops.OpOverload, int], Node + ] = {} # map from hash to a node in the new graph + token_map: dict[ + tuple[torch._ops.OpOverload, int], dict[str, Any] + ] = {} # map from hash to token + for n in graph_module.graph.nodes: + # The placeholder, output, and get_attr nodes are copied to the new graph without change + # do not CSE away random operations + if ( + n.op == "placeholder" + or n.op == "output" + or n.op == "get_attr" + or get_aten_target(n) in self.banned_ops + ): + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' + # substitute args and kwargs members to their mapping in env if exists + # specs can be used to reconstruct nested list/dictionaries + def substitute(arg_list): + arg_list, spec = tree_flatten(arg_list) + for i in range(len(arg_list)): + v = arg_list[i] + if isinstance(v, Node) and v in env: + arg_list[i] = env[v] + return tuple(arg_list), spec + + args, args_spec = substitute(n.args) + kwargs, kwargs_spec = substitute(n.kwargs) + + # each token corresponds to a unique node + # nodes with the same token can be substituted + token = { + "target": n.target, + "args": args, + "args_spec": args_spec, + "kwargs": kwargs, + "kwargs_spec": kwargs_spec, + } + + # hash substituted args to a number, do not hash specs because specs are not hashable + hash_arg = hash((args, kwargs)) + hash_val = (n.target, hash_arg) + + # check if a node has a substitute and can be eliminated + hash_val_in_hash_env = hash_val in hash_env + if hash_val_in_hash_env and token_map[hash_val] == token: + modified = True # substitution happens and the graph is modified + env[n] = hash_env[hash_val] + continue + + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + if not hash_val_in_hash_env: + hash_env[hash_val] = new_node + token_map[hash_val] = token + + csed_gm = GraphModule(graph_module, new_graph) + return PassResult(csed_gm, modified) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/fake_tensor_prop.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/fake_tensor_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..43dbe86c7370f66aa30b5fbc5853d5a0d12cd8ad --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/fake_tensor_prop.py @@ -0,0 +1,109 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch.fx +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx import Node +from torch.fx._compatibility import compatibility +from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake +from torch.fx.node import map_aggregate +from torch.utils._ordered_set import OrderedSet + + +__all__ = ["FakeTensorProp"] + + +@compatibility(is_backward_compatible=False) +class FakeTensorProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node and record a fake tensor representing + the metadata for the node. Unlike ShapeProp, (1) this propagation + is cheap--it does the propagation with meta tensors which do not actually + store data, and (2) the fake tensors have much more fine grained information, + e.g., they have accurate alias information that can be consulted by looking + at the storages. + + Args: + module (GraphModule): The module to be executed + mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node. + """ + + def __init__( + self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None + ): + super().__init__(module) + if mode is None: + mode = FakeTensorMode() + self._mode = mode + mode.epoch += 1 + mode.reset_nt_tensor_id_counter() + self.seen_subgraphs: OrderedSet[str] = OrderedSet() + + def run_node(self, n: Node): + from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + rebind_unbacked, + ) + + if ( + n.op == "call_function" + and n.target is torch.ops.higher_order.invoke_subgraph + and n.args[1] not in self.seen_subgraphs + ): + # Prevent redundant fake tensor prop for invoke_subgraphs. Note that + # there is also fake tensor caching for the entire subgraph. This + # happens the next time we call `run_node` for the same subgraph, + # which goes through super.run_node and caches the fake tensor prop. + # Therefore, we are propagating fake tensor through the subgraphs + # twice. + assert isinstance(n.args[1], str) + assert ( + isinstance(n.args[0], torch.fx.Node) + and n.args[0].op == "get_attr" + and isinstance(n.args[0].target, str) + ) + self.seen_subgraphs.add(n.args[1]) + operands = n.args[2:] + example_inputs = [] + for operand in operands: + assert isinstance(operand, torch.fx.Node) and "val" in operand.meta + example_inputs.append(operand.meta["val"]) + return FakeTensorProp( + getattr(self.module, n.args[0].target), mode=self._mode + ).propagate(*example_inputs) + + result = super().run_node(n) + rebind_unbacked(self._mode.shape_env, n, result) + + def extract_val(obj): + if isinstance(obj, FakeTensor): + return snapshot_fake(obj) + elif isinstance(obj, torch.Tensor): + # TODO: How is it possible that we get a non fake tensor? We + # should be running under the mode... + return snapshot_fake(self._mode.from_tensor(obj, static_shapes=True)) + elif isinstance(obj, py_sym_types): + return obj + else: + return None + + meta = map_aggregate(result, extract_val) + if meta is not None: + n.meta["val"] = meta + if (shape_env := self._mode.shape_env) and ( + symbol_to_path := compute_unbacked_bindings(shape_env, result) + ): + n.meta["unbacked_bindings"] = symbol_to_path + + return result + + def propagate(self, *args): + fake_args = [ + self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a + for a in args + ] + return self.propagate_dont_convert_inputs(*fake_args) + + def propagate_dont_convert_inputs(self, *args): + with self._mode: + return super().run(*args) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/graph_drawer.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/graph_drawer.py new file mode 100644 index 0000000000000000000000000000000000000000..a5445a6851fa99e38c484f53ee5b4e2b85f5369d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/graph_drawer.py @@ -0,0 +1,501 @@ +# mypy: allow-untyped-defs + +import hashlib +from itertools import chain +from types import ModuleType +from typing import Any, Optional, TYPE_CHECKING + +import torch +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.graph import _parse_stack_trace +from torch.fx.node import _format_arg, _get_qualified_name +from torch.fx.operator_schemas import normalize_function +from torch.fx.passes.shape_prop import TensorMetadata + + +if TYPE_CHECKING: + import pydot + + HAS_PYDOT = True +else: + pydot: Optional[ModuleType] + try: + import pydot + + HAS_PYDOT = True + except ModuleNotFoundError: + HAS_PYDOT = False + pydot = None + + +__all__ = ["FxGraphDrawer"] + +_COLOR_MAP = { + "placeholder": '"AliceBlue"', + "call_module": "LemonChiffon1", + "get_param": "Yellow2", + "get_attr": "LightGrey", + "output": "PowderBlue", +} + +_HASH_COLOR_MAP = [ + "CadetBlue1", + "Coral", + "DarkOliveGreen1", + "DarkSeaGreen1", + "GhostWhite", + "Khaki1", + "LavenderBlush1", + "LightSkyBlue", + "MistyRose1", + "MistyRose2", + "PaleTurquoise2", + "PeachPuff1", + "Salmon", + "Thistle1", + "Thistle3", + "Wheat1", +] + +_WEIGHT_TEMPLATE = { + "fillcolor": "Salmon", + "style": '"filled,rounded"', + "fontcolor": "#000000", +} + +if HAS_PYDOT: + + @compatibility(is_backward_compatible=False) + class FxGraphDrawer: + """ + Visualize a torch.fx.Graph with graphviz + Basic usage: + g = FxGraphDrawer(symbolic_traced, "resnet18") + g.get_dot_graph().write_svg("a.svg") + """ + + def __init__( + self, + graph_module: torch.fx.GraphModule, + name: str, + ignore_getattr: bool = False, + ignore_parameters_and_buffers: bool = False, + skip_node_names_in_args: bool = True, + parse_stack_trace: bool = False, + dot_graph_shape: Optional[str] = None, + normalize_args: bool = False, + ): + self._name = name + self.dot_graph_shape = ( + dot_graph_shape if dot_graph_shape is not None else "record" + ) + self.normalize_args = normalize_args + _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape + + self._dot_graphs = { + name: self._to_dot( + graph_module, + name, + ignore_getattr, + ignore_parameters_and_buffers, + skip_node_names_in_args, + parse_stack_trace, + ) + } + + for node in graph_module.graph.nodes: + if node.op != "call_module": + continue + + leaf_node = self._get_leaf_node(graph_module, node) + + if not isinstance(leaf_node, torch.fx.GraphModule): + continue + + self._dot_graphs[f"{name}_{node.target}"] = self._to_dot( + leaf_node, + f"{name}_{node.target}", + ignore_getattr, + ignore_parameters_and_buffers, + skip_node_names_in_args, + parse_stack_trace, + ) + + def get_dot_graph(self, submod_name=None) -> pydot.Dot: + """ + Visualize a torch.fx.Graph with graphviz + Example: + >>> # xdoctest: +REQUIRES(module:pydot) + >>> # xdoctest: +REQUIRES(module:ubelt) + >>> # define module + >>> class MyModule(torch.nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.linear = torch.nn.Linear(4, 5) + >>> def forward(self, x): + >>> return self.linear(x).clamp(min=0.0, max=1.0) + >>> module = MyModule() + >>> # trace the module + >>> symbolic_traced = torch.fx.symbolic_trace(module) + >>> # setup output file + >>> import ubelt as ub + >>> dpath = ub.Path.appdir("torch/tests/FxGraphDrawer").ensuredir() + >>> fpath = dpath / "linear.svg" + >>> # draw the graph + >>> g = FxGraphDrawer(symbolic_traced, "linear") + >>> g.get_dot_graph().write_svg(fpath) + """ + if submod_name is None: + return self.get_main_dot_graph() + else: + return self.get_submod_dot_graph(submod_name) + + def get_main_dot_graph(self) -> pydot.Dot: + return self._dot_graphs[self._name] + + def get_submod_dot_graph(self, submod_name) -> pydot.Dot: + return self._dot_graphs[f"{self._name}_{submod_name}"] + + def get_all_dot_graphs(self) -> dict[str, pydot.Dot]: + return self._dot_graphs + + def _get_node_style(self, node: torch.fx.Node) -> dict[str, str]: + template = { + "shape": self.dot_graph_shape, + "fillcolor": "#CAFFE3", + "style": '"filled,rounded"', + "fontcolor": "#000000", + } + if node.op in _COLOR_MAP: + template["fillcolor"] = _COLOR_MAP[node.op] + else: + # Use a random color for each node; based on its name so it's stable. + target_name = node._pretty_print_target(node.target) + target_hash = int( + hashlib.md5( + target_name.encode(), usedforsecurity=False + ).hexdigest()[:8], + 16, + ) + template["fillcolor"] = _HASH_COLOR_MAP[ + target_hash % len(_HASH_COLOR_MAP) + ] + return template + + def _get_leaf_node( + self, module: torch.nn.Module, node: torch.fx.Node + ) -> torch.nn.Module: + py_obj = module + assert isinstance(node.target, str) + atoms = node.target.split(".") + for atom in atoms: + if not hasattr(py_obj, atom): + raise RuntimeError( + str(py_obj) + " does not have attribute " + atom + "!" + ) + py_obj = getattr(py_obj, atom) + return py_obj + + def _typename(self, target: Any) -> str: + if isinstance(target, torch.nn.Module): + ret = torch.typename(target) + elif isinstance(target, str): + ret = target + else: + ret = _get_qualified_name(target) + + # Escape "{" and "}" to prevent dot files like: + # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc + # which triggers `Error: bad label format (...)` from dot + return ret.replace("{", r"\{").replace("}", r"\}") + + # shorten path to avoid drawing long boxes + # for full path = '/home/weif/pytorch/test.py' + # return short path = 'pytorch/test.py' + def _shorten_file_name( + self, + full_file_name: str, + truncate_to_last_n: int = 2, + ): + splits = full_file_name.split("/") + if len(splits) >= truncate_to_last_n: + return "/".join(splits[-truncate_to_last_n:]) + return full_file_name + + def _get_node_label( + self, + module: torch.fx.GraphModule, + node: torch.fx.Node, + skip_node_names_in_args: bool, + parse_stack_trace: bool, + ) -> str: + def _get_str_for_args_kwargs(arg): + if isinstance(arg, tuple): + prefix, suffix = r"|args=(\l", r",\n)\l" + arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg] + elif isinstance(arg, dict): + prefix, suffix = r"|kwargs={\l", r",\n}\l" + arg_strs_list = [ + f"{k}: {_format_arg(v, max_list_len=8)}" for k, v in arg.items() + ] + else: # Fall back to nothing in unexpected case. + return "" + + # Strip out node names if requested. + if skip_node_names_in_args: + arg_strs_list = [a for a in arg_strs_list if "%" not in a] + if len(arg_strs_list) == 0: + return "" + arg_strs = prefix + r",\n".join(arg_strs_list) + suffix + if len(arg_strs_list) == 1: + arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "") + return arg_strs.replace("{", r"\{").replace("}", r"\}") + + label = "{" + f"name=%{node.name}|op_code={node.op}\n" + + if node.op == "call_module": + leaf_module = self._get_leaf_node(module, node) + label += r"\n" + self._typename(leaf_module) + r"\n|" + extra = "" + if hasattr(leaf_module, "__constants__"): + extra = r"\n".join( + [ + f"{c}: {getattr(leaf_module, c)}" + for c in leaf_module.__constants__ # type: ignore[union-attr] + ] # type: ignore[union-attr] + ) + label += extra + r"\n" + else: + label += f"|target={self._typename(node.target)}" + r"\n" + if self.normalize_args: + try: + args, kwargs = normalize_function( # type: ignore[misc] + node.target, # type: ignore[arg-type] + node.args, # type: ignore[arg-type] + node.kwargs, + normalize_to_only_use_kwargs=True, + ) + except Exception: + # Fallback to not normalizing if there's an exception. + # Some functions need overloads specified to normalize. + args, kwargs = node.args, node.kwargs + else: + args, kwargs = node.args, node.kwargs + if len(args) > 0: + label += _get_str_for_args_kwargs(args) + if len(kwargs) > 0: + label += _get_str_for_args_kwargs(kwargs) + label += f"|num_users={len(node.users)}" + r"\n" + + tensor_meta = node.meta.get("tensor_meta") + label += self._tensor_meta_to_label(tensor_meta) + + # for original fx graph + # print buf=buf0, n_origin=6 + buf_meta = node.meta.get("buf_meta", None) + if buf_meta is not None: + label += f"|buf={buf_meta.name}" + r"\n" + label += f"|n_origin={buf_meta.n_origin}" + r"\n" + + # for original fx graph + # print file:lineno code + if parse_stack_trace and node.stack_trace is not None: + parsed_stack_trace = _parse_stack_trace(node.stack_trace) + fname = self._shorten_file_name(parsed_stack_trace.file) + label += ( + f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + + r"\n" + ) + + return label + "}" + + def _tensor_meta_to_label(self, tm) -> str: + if tm is None: + return "" + elif isinstance(tm, TensorMetadata): + return self._stringify_tensor_meta(tm) + elif isinstance(tm, list): + result = "" + for item in tm: + result += self._tensor_meta_to_label(item) + return result + elif isinstance(tm, dict): + result = "" + for v in tm.values(): + result += self._tensor_meta_to_label(v) + return result + elif isinstance(tm, tuple): + result = "" + for item in tm: + result += self._tensor_meta_to_label(item) + return result + else: + raise RuntimeError(f"Unsupported tensor meta type {type(tm)}") + + def _stringify_tensor_meta(self, tm: TensorMetadata) -> str: + result = "" + if not hasattr(tm, "dtype"): + print("tm", tm) + result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n" + result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n" + result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n" + result += "|" + "stride" + "=" + str(tm.stride) + r"\n" + if tm.is_quantized: + assert tm.qparams is not None + assert "qscheme" in tm.qparams + qscheme = tm.qparams["qscheme"] + if qscheme in { + torch.per_tensor_affine, + torch.per_tensor_symmetric, + }: + result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n" + result += ( + "|" + + "q_zero_point" + + "=" + + str(tm.qparams["zero_point"]) + + r"\n" + ) + elif qscheme in { + torch.per_channel_affine, + torch.per_channel_symmetric, + torch.per_channel_affine_float_qparams, + }: + result += ( + "|" + + "q_per_channel_scale" + + "=" + + str(tm.qparams["scale"]) + + r"\n" + ) + result += ( + "|" + + "q_per_channel_zero_point" + + "=" + + str(tm.qparams["zero_point"]) + + r"\n" + ) + result += ( + "|" + + "q_per_channel_axis" + + "=" + + str(tm.qparams["axis"]) + + r"\n" + ) + else: + raise RuntimeError(f"Unsupported qscheme: {qscheme}") + result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n" + return result + + def _get_tensor_label(self, t: torch.Tensor) -> str: + return str(t.dtype) + str(list(t.shape)) + r"\n" + + # when parse_stack_trace=True + # print file:lineno code + def _to_dot( + self, + graph_module: torch.fx.GraphModule, + name: str, + ignore_getattr: bool, + ignore_parameters_and_buffers: bool, + skip_node_names_in_args: bool, + parse_stack_trace: bool, + ) -> pydot.Dot: + """ + Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph. + If ignore_parameters_and_buffers is True, the parameters and buffers + created with the module will not be added as nodes and edges. + """ + + # "TB" means top-to-bottom rank direction in layout + dot_graph = pydot.Dot(name, rankdir="TB") + + buf_name_to_subgraph = {} + + for node in graph_module.graph.nodes: + if ignore_getattr and node.op == "get_attr": + continue + + style = self._get_node_style(node) + dot_node = pydot.Node( + node.name, + label=self._get_node_label( + graph_module, node, skip_node_names_in_args, parse_stack_trace + ), + **style, # type: ignore[arg-type] + ) + + current_graph = dot_graph + + buf_meta = node.meta.get("buf_meta", None) + if buf_meta is not None and buf_meta.n_origin > 1: + buf_name = buf_meta.name + if buf_name not in buf_name_to_subgraph: + buf_name_to_subgraph[buf_name] = pydot.Cluster( + buf_name, label=buf_name + ) + current_graph = buf_name_to_subgraph.get(buf_name) # type: ignore[assignment] + + current_graph.add_node(dot_node) + + def get_module_params_or_buffers(): + for pname, ptensor in chain( + leaf_module.named_parameters(), leaf_module.named_buffers() + ): + pname1 = node.name + "." + pname + label1 = ( + pname1 + "|op_code=get_" + "parameter" + if isinstance(ptensor, torch.nn.Parameter) + else "buffer" + r"\l" + ) + dot_w_node = pydot.Node( + pname1, + label="{" + label1 + self._get_tensor_label(ptensor) + "}", + **_WEIGHT_TEMPLATE, # type: ignore[arg-type] + ) + dot_graph.add_node(dot_w_node) + dot_graph.add_edge(pydot.Edge(pname1, node.name)) + + if node.op == "call_module": + leaf_module = self._get_leaf_node(graph_module, node) + + if not ignore_parameters_and_buffers and not isinstance( + leaf_module, torch.fx.GraphModule + ): + get_module_params_or_buffers() + + for subgraph in buf_name_to_subgraph.values(): + subgraph.set("color", "royalblue") + subgraph.set("penwidth", "2") + dot_graph.add_subgraph(subgraph) # type: ignore[arg-type] + + for node in graph_module.graph.nodes: + if ignore_getattr and node.op == "get_attr": + continue + + for user in node.users: + dot_graph.add_edge(pydot.Edge(node.name, user.name)) + + return dot_graph + +else: + if not TYPE_CHECKING: + + @compatibility(is_backward_compatible=False) + class FxGraphDrawer: + def __init__( + self, + graph_module: torch.fx.GraphModule, + name: str, + ignore_getattr: bool = False, + ignore_parameters_and_buffers: bool = False, + skip_node_names_in_args: bool = True, + parse_stack_trace: bool = False, + dot_graph_shape: Optional[str] = None, + normalize_args: bool = False, + ): + raise RuntimeError( + "FXGraphDrawer requires the pydot package to be installed. Please install " + "pydot through your favorite Python package manager." + ) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/graph_manipulation.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/graph_manipulation.py new file mode 100644 index 0000000000000000000000000000000000000000..f559aa0bfcb3d96733f479d864aaab40923c473c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/graph_manipulation.py @@ -0,0 +1,113 @@ +# mypy: allow-untyped-defs +from typing import Any, NamedTuple, Optional + +import torch +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.node import map_arg, Node, Target +from torch.fx.passes.shape_prop import ShapeProp + + +__all__ = [ + "replace_target_nodes_with", + "size_bytes", + "get_size_of_all_nodes", + "get_tensor_meta", + "get_size_of_node", +] + + +@compatibility(is_backward_compatible=False) +def replace_target_nodes_with( + fx_module: GraphModule, + old_op: str, + old_target: Target, + new_op: str, + new_target: Target, +): + """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target, + and updates them to match the new op code and target""" + new_graph = Graph() + val_map: dict[Node, Node] = {} + for node in fx_module.graph.nodes: + if node.op == old_op and node.target == old_target: + args = map_arg(node.args, lambda n: val_map[n]) + kwargs = map_arg(node.kwargs, lambda n: val_map[n]) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + val_map[node] = new_graph.create_node( + new_op, new_target, args, kwargs, node.name + ) + else: + val_map[node] = new_graph.node_copy(node, lambda n: val_map[n]) + fx_module.graph = new_graph + + +@compatibility(is_backward_compatible=False) +class size_bytes(NamedTuple): + output_size: int + total_size: int + + +@compatibility(is_backward_compatible=False) +def get_size_of_all_nodes( + fx_module: GraphModule, args: Optional[list[torch.Tensor]] = None +) -> None: + """Given a fx graph module, update each node with its total size (weights + bias + output) + and its output_size(output). For a non-module node, the total size is the output size. + return total size""" + if args is not None: + # Mark shape and dtype for each node (node.shape and node.dtype) + ShapeProp(fx_module).propagate(*args) + # Calculate the total size of the whole fx graph + for node in fx_module.graph.nodes: + if node.op == "output": + break + node.size_bytes = get_size_of_node(fx_module, node) + return + + +@compatibility(is_backward_compatible=False) +def get_tensor_meta(node: Node) -> Any: + tensor_meta = node.meta.get("tensor_meta") + + if not tensor_meta: + raise RuntimeError( + f"Node {node} has no tensor metadata associated with it! " + f"Check that shape propagation has run." + ) + + return tensor_meta + + +@compatibility(is_backward_compatible=False) +def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes: + """Given a node with node.dtype and node.shape, return its total size and its output size. + total_size = weights + bias + output_size + """ + # Total num of elements + total_num_of_elems = 0 + # For a module, conside all parameters + if node.op == "call_module": + submodule_dict = dict(fx_module.named_modules()) + submodule = submodule_dict[node.target] + parameters = submodule.named_parameters() + # Parameters are named tuples + for _name, p in parameters: + total_num_of_elems += p.numel() + # Don't forget the output size + # node.shape is the shape of this node's output + tensor_meta = get_tensor_meta(node) + output_elem = tensor_meta.shape.numel() + total_num_of_elems += output_elem + # Assume for now if it's quantized then it's qint8 or quint8 + if tensor_meta.is_quantized: + size_per_elem_bytes = torch._empty_affine_quantized( + [], dtype=tensor_meta.dtype + ).element_size() + else: + size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size() + total_size = size_per_elem_bytes * total_num_of_elems + output_size = size_per_elem_bytes * output_elem + return size_bytes(output_size, total_size) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/graph_transform_observer.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/graph_transform_observer.py new file mode 100644 index 0000000000000000000000000000000000000000..e19abc7ad3d8b726adeb22082b4c11dfe1cbefc2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/graph_transform_observer.py @@ -0,0 +1,219 @@ +# mypy: allow-untyped-defs +import os +from typing import Callable, Optional, TypeVar + +from torch.fx import Graph, Node +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule +from torch.fx.traceback import NodeSource, NodeSourceAction + + +T = TypeVar("T") + + +from .graph_drawer import FxGraphDrawer + + +__all__ = ["GraphTransformObserver"] + + +@compatibility(is_backward_compatible=False) +class GraphTransformObserver: + __pass_count = 0 + + def __init__( + self, + gm: GraphModule, + passname: str, + subsystem: Optional[str] = None, + log_url: Optional[str] = None, + ): + """ + log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified + """ + from torch._inductor.config import trace + + self.gm = gm + self.passname = passname + self.subsystem = subsystem + + if log_url is None: + log_url = trace.log_url_for_graph_xform + + self.log_url = log_url + + self.active = trace.enabled or self.log_url is not None + + if self.active: + self.erased_nodes: set[str] = set() + self.created_nodes: set[str] = set() + self.name_to_node: dict[str, Node] = {} + # record graph modules deepcopied from self.gm, so we can remove hoooks on them when exiting the context + self.copied_gms: list[GraphModule] = [] + + self._node_creation_hook = self.get_node_creation_hook() + self._node_erase_hook = self.get_node_erase_hook() + self._node_replace_hook = self.get_node_replace_hook() + self._deepcopy_hook = self.get_deepcopy_hook() + + # If log_url is None, we don't log anything + if self.log_url is None: + return + GraphTransformObserver.__pass_count += 1 + + self.input_dot_graph = FxGraphDrawer( + self.gm, + self.passname, + ignore_getattr=True, + ignore_parameters_and_buffers=True, + ).get_dot_graph() + + @classmethod + def get_current_pass_count(cls): + return cls.__pass_count + + def apply_gm_pass(self, pass_fn: Callable[[GraphModule], T]) -> Optional[T]: + with self: + if not self._check_disable_pass(): + return pass_fn(self.gm) + + return None + + def apply_graph_pass(self, pass_fn: Callable[[Graph], T]) -> Optional[T]: + with self: + if not self._check_disable_pass(): + return pass_fn(self.gm.graph) + + return None + + def _check_disable_pass(self): + if self.subsystem is None: + return False + + debug_info = lambda: self.passname # noqa: E731 + from torch._inductor.compiler_bisector import CompilerBisector + + return CompilerBisector.disable_subsystem( + "inductor", self.subsystem, debug_info + ) + + def __enter__(self): + if not self.active: + return self + self.gm._register_create_node_hook(self._node_creation_hook) + self.gm._register_erase_node_hook(self._node_erase_hook) + self.gm._register_replace_node_hook(self._node_replace_hook) + self.gm._register_deepcopy_hook(self._deepcopy_hook) + + self.erased_nodes.clear() + self.created_nodes.clear() + self.name_to_node.clear() + self.copied_gms.clear() + + for node in self.gm.graph.nodes: + self.name_to_node[node.name] = node + + return self + + def __exit__(self, type, value, tb): + if not self.active: + return + for gm in self.copied_gms + [self.gm]: + gm._unregister_create_node_hook(self._node_creation_hook) + gm._unregister_erase_node_hook(self._node_erase_hook) + gm._unregister_replace_node_hook(self._node_replace_hook) + gm._unregister_deepcopy_hook(self._deepcopy_hook) + + if self.log_url is None: + return + + if len(self.created_nodes) > 0 or len(self.erased_nodes) > 0: + for e in self.input_dot_graph.get_node_list(): + if e.get_name() in self.erased_nodes: + e.obj_dict["attributes"]["fillcolor"] = "yellow" + else: + e.obj_dict["attributes"]["fillcolor"] = "grey" + assert self.log_url is not None + self.input_dot_graph.write( + os.path.join( + self.log_url, + f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_input_graph.dot", + ) + ) + + output_dot_graph = FxGraphDrawer( + self.gm, + self.passname, + ignore_getattr=True, + ignore_parameters_and_buffers=True, + ).get_dot_graph() + for e in output_dot_graph.get_node_list(): + if e.get_name() in self.created_nodes: + e.obj_dict["attributes"]["fillcolor"] = "yellow" + else: + e.obj_dict["attributes"]["fillcolor"] = "grey" + output_dot_graph.write( + os.path.join( + self.log_url, + f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_output_graph.dot", + ) + ) + + def get_node_creation_hook(self): + # We have to return a function instead of using a class method directly + # to avoid max recursion issue when deepcopy a graph module within the context manager. + def on_node_creation(node): + self.created_nodes.add(node.name) + self.name_to_node[node.name] = node + source = NodeSource(None, self.passname, NodeSourceAction.CREATE) + if "from_node" not in node.meta: + node.meta["from_node"] = [source] + else: + node.meta["from_node"].append(source) + + return on_node_creation + + def get_node_erase_hook(self): + def on_node_erase(node): + self.erased_nodes.add(node.name) + self.name_to_node.pop(node.name, None) + + return on_node_erase + + def get_node_replace_hook(self): + def on_node_replace(old: Node, new: str, user: Node): + # Update node meta when replacing old node with new node + new_node = self.name_to_node.get(new, None) + + if not new_node: + return + + assert isinstance(new_node, Node) + + action = [NodeSourceAction.REPLACE] + if new_node.name in self.created_nodes: + action.append(NodeSourceAction.CREATE) + + def created_this_pass(source): + return source.pass_name == self.passname and source.action == [ + NodeSourceAction.CREATE + ] + + # remove redundant source added on node creation + new_from_node = new_node.meta.get("from_node", []) + new_from_node = [ + source for source in new_from_node if not created_this_pass(source) + ] + + # add new source + new_node_source = NodeSource(old, self.passname, action) + new_from_node.append(new_node_source) + new_node.meta["from_node"] = new_from_node + + return on_node_replace + + def get_deepcopy_hook(self): + def on_deepcopy(gm): + self.copied_gms.append(gm) + + return on_deepcopy diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__init__.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..939157f1302e75e3cf17ec3c1e93d1b8993d67a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__init__.py @@ -0,0 +1 @@ +from . import pass_manager diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9575e2ad5ff04126ea75795e496548b359c47fd0 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00f9ff678f2bf47ee7bc1dda25139760ba1320ad Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95a5f23040e85e2251aac21eae37e2ee2785426c Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/partitioner.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..438661090942a876fa88b83a485791d37b591a61 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/partitioner.py @@ -0,0 +1,376 @@ +# mypy: allow-untyped-defs +import collections +import itertools +import logging +import operator +from collections.abc import Iterable, Sequence +from typing import Optional + +from torch.fx.graph_module import GraphModule +from torch.fx.node import _get_qualified_name, Node +from torch.fx.passes.operator_support import OperatorSupportBase +from torch.fx.passes.utils.fuser_utils import fuse_by_partitions + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +class Partition: + def __init__( + self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None + ): + self.id = id + self.nodes = dict.fromkeys(nodes) if nodes is not None else {} + + def __repr__(self) -> str: + return str(self.nodes) + + def add_node(self, node: Node): + self.nodes.update({node: None}) + + def remove_node(self, node: Node): + del self.nodes[node] + + def size(self): + return len(self.nodes) + + +class _DependencyViewer: + def __init__(self, graph_module: GraphModule): + self.downstreams = collections.defaultdict(set) + + for node in reversed(graph_module.graph.nodes): + for output_node in node.users: + # add output_node and output_node's downstream dependency + self.downstreams[node].add(output_node) + self.downstreams[node].update(self.downstreams[output_node]) + + def downstreams_of(self, node: Node) -> set[Node]: + return self.downstreams[node] + + +class CapabilityBasedPartitioner: + def __init__( + self, + graph_module: GraphModule, + operator_support: OperatorSupportBase, + allows_single_node_partition: bool = False, + non_compute_ops: Optional[Sequence[str]] = None, + allowed_single_node_partition_ops: Optional[Sequence[str]] = None, + ) -> None: + self.graph_module = graph_module + self.operator_support = operator_support + self.allows_single_node_partition = allows_single_node_partition + self.non_compute_ops = non_compute_ops if non_compute_ops is not None else [] + self.allowed_single_node_partition_ops = ( + allowed_single_node_partition_ops + if allowed_single_node_partition_ops is not None + else [] + ) + self.dependency_viewer = _DependencyViewer(graph_module) + + def _is_node_supported(self, node: Node) -> bool: + return self.operator_support.is_node_supported( + dict(self.graph_module.named_modules()), node + ) + + def propose_partitions(self) -> list[Partition]: + # partition_map is a mapping from partition id to a set of partition id's. + # The value set contains all the partition ids that can be reached by doing a + # DFS starting from the partition id in the key. + partition_map: dict[int, set] = collections.defaultdict(set) + + # assumptions: nodes in candidate list is sorted in topological order + assignment: dict[Node, int] = {} # mapping from node to partition_id + partitions_by_id: dict[ + int, Partition + ] = {} # mapping from partition_id to partition + nodes_order: dict[ + Node, int + ] = {} # mapping from nodes to reversed topological order + partitions_order: dict[ + int, int + ] = {} # mapping from partition_id to minimum topo order of nodes in partition + partition_users: dict[ + int, set + ] = {} # mapping from partition_id to partition users + new_partition_id = itertools.count() + + # try to merge partition other_id into partition self_id + # merge only happens if the end graph doesn't contain cyclic dependency + # returns `True` when merge happens, `False` otherwise. + def maybe_merge_partition(self_id: int, other_id: int): + # merged_nodes is the union of nodes in two partition to-be-merged + self_nodes = partitions_by_id[self_id].nodes + other_nodes = partitions_by_id[other_id].nodes + + def dfs_iter_find_cycle(all_user_nodes: set[Node]): + for user_node in all_user_nodes: + visited_partition_ids = set() + + for path_node in self.dependency_viewer.downstreams_of(user_node): + # If any of the nodes in the dfs path of this node are in the merged_nodes + # list then there is a cycle in the graph. + if path_node in self_nodes or path_node in other_nodes: + return True + + # If any of the nodes in the dfs path of this node are in the assignment + # map then we have to make sure that the partitions that these nodes belong + # to do not form a cycle with the current partitions being merged. This means + # iterating through all the nodes in all the parititons that are traversed in + # the dfs path and checking if they are in the merged_nodes list. + if path_node in assignment: + partition_id = assignment[path_node] + # If the partition id has already been visited then we know that it doesn't + # form a cycle with the current partitions being merged. + if partition_id in visited_partition_ids: + continue + p_map = partition_map[partition_id] + if self_id in p_map or other_id in p_map: + return True + + visited_partition_ids.add(partition_id) + + return False + + # find new partition users if merge. + all_user_nodes = partition_users[self_id] | partition_users[other_id] + all_user_nodes.difference_update(other_nodes, self_nodes) + + # check if merge would create cyclic dependency. + if dfs_iter_find_cycle(all_user_nodes): + # return false indicating cyclic dependency found and + # merge is aborted + return self_id, False + + # merge the smaller partition into the larger. + merge_id, removed_id = self_id, other_id + if len(self_nodes) < len(other_nodes): + merge_id, removed_id = removed_id, merge_id + # no cyclic dependency found, move forward with the merge + # updating partition nodes + partitions_by_id[merge_id].nodes.update(partitions_by_id[removed_id].nodes) + # updating assignment map + for node in partitions_by_id[removed_id].nodes: + assignment[node] = merge_id + # delete other partition + del partitions_by_id[removed_id] + + partitions_order[merge_id] = min( + partitions_order[merge_id], partitions_order[removed_id] + ) + del partitions_order[removed_id] + + partition_map[merge_id] = partition_map[merge_id].union( + partition_map[removed_id] + ) + del partition_map[removed_id] + + partition_users[merge_id] = all_user_nodes + del partition_users[removed_id] + + return merge_id, True + + def merge_single_node(node: Node, id: Optional[int]): + def _update_partition_map(node: Node, id: int): + # Iterate through all the users of this node and update the partition map to indicate + # that there is a path from the partition id of this node to the target partition id. + for user_node in node.users: + target_id = assignment.get(user_node, None) + if target_id is not None: + partition_map[id].add(target_id) + partition_map[id].update(partition_map[target_id]) + + if node in assignment: + partitions_by_id[assignment[node]].remove_node(node) + + if id is None: + assignment.pop(node) + elif id not in partitions_by_id: + assignment[node] = id + partitions_by_id[id] = Partition(id=id, nodes=[node]) + partition_users[id] = set(node.users) + _update_partition_map(node, id) + else: + assignment[node] = id + partitions_by_id[id].add_node(node) + + logger.debug("Proposing partitions...") + + for node in reversed(self.graph_module.graph.nodes): + # use Dict as an ordered set to ensure deterministic partitioning result, don't care value + merge_candidates: dict[int, None] = {} + + # Note a limited horizontal fusion is enabled: + # when `node` is not supported, the code below attempts to fuse consumer of `node`. + # + # I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut + # the fusion by adding an `else` block here to skip horizontal fusion. + if self._is_node_supported(node) and node not in assignment: + partition_id = next(new_partition_id) + nodes_order[node] = partition_id + partitions_order[partition_id] = partition_id + merge_single_node(node, partition_id) + merge_candidates[partition_id] = None + + # merge all possible partitions + for partition_id, _ in sorted( + partitions_order.items(), key=operator.itemgetter(1) + ): + merge_candidates[partition_id] = None + + merge_candidates_list = list(merge_candidates.keys()) + if len(merge_candidates_list) > 1: + self_id = merge_candidates_list[0] + for other_id in merge_candidates_list[1:]: + # note: merge partitions if it doesn't create cyclic dependency + # in the graph, otherwise, this is a no-op + self_id, _ = maybe_merge_partition(self_id, other_id) + + # post processing to re-assign "getitem" nodes into upstream partition + logger.debug("Reassigning getitem nodes to its producer node's partition...") + nodes_reassignment: dict[Node, int] = {} + for node in self.graph_module.graph.nodes: + is_tuple_output = True + for user in node.users: + if ( + user.op != "call_function" + or _get_qualified_name(user.target) != "_operator.getitem" + ): # type: ignore[arg-type] + is_tuple_output = False + break + + # node has tuple outputs, re-assign all following getitem node into node's partition + if is_tuple_output: + id = assignment.get(node, None) # type: ignore[arg-type] + for user in node.users: + if assignment.get(user, None) != id: # type: ignore[arg-type] + nodes_reassignment[user] = id # type: ignore[assignment] + for node, id in nodes_reassignment.items(): + merge_single_node(node, id) + + # filter out single node partitions + if not self.allows_single_node_partition: + logger.debug("Filtering out single node partitions...") + default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"} + non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops)) + partitions_to_remove: list[int] = [] + for id, partition in partitions_by_id.items(): + compute_node_count = 0 + for node in partition.nodes: + if node.op == "call_function": + assert callable(node.target) + if _get_qualified_name(node.target) not in non_compute_ops: + compute_node_count += 1 + if ( + _get_qualified_name(node.target) + in self.allowed_single_node_partition_ops + ): + compute_node_count += 1 + if compute_node_count <= 1: + partitions_to_remove.append(id) + for id in partitions_to_remove: + del partitions_by_id[id] + + logger.debug("Partitions proposed:") + for id, partition in partitions_by_id.items(): + logger.debug( + "partition #%s: %s", id, [node.name for node in partition.nodes] + ) + + return [ + partition for partition in partitions_by_id.values() if partition.size() > 0 + ] + + def fuse_partitions( + self, partitions: list[Partition], prefix: str = "fused_" + ) -> GraphModule: + logger.debug("Fusing partitions...") + # fuse_by_partitions expects partitions in List[Dict[Node, None]]: [ {node0 : None}, {node1 : None} ] + return fuse_by_partitions( + self.graph_module, + [partition.nodes for partition in partitions], + prefix=prefix, + ) + + # remove non-compute-ops that sits at the boundary of a partition. + def remove_bookend_non_compute_ops(self, partitions: list[Partition]): + non_compute_ops = set(self.non_compute_ops) + + def is_non_compute_node(node: Node): + return ( + node.op == "call_function" + and _get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type] + ) + + # cache transparent nodes + transparent_input_nodes: dict[Node, bool] = {} + transparent_output_nodes: dict[Node, bool] = {} + + def is_transparent_input_node( + node: Node, partition: set[Node], removed_nodes: set[Node] + ): + if ( + node.op == "placeholder" + or (node not in partition) + or (node in removed_nodes) + ): + return True + if node in transparent_input_nodes: + return transparent_input_nodes[node] + if is_non_compute_node(node): + for input_n in node.all_input_nodes: + if not is_transparent_input_node(input_n, partition, removed_nodes): + transparent_input_nodes[node] = False + return False + transparent_input_nodes[node] = True + return True + transparent_input_nodes[node] = False + return False + + def is_transparent_output_node( + node: Node, partition: set[Node], removed_nodes: set[Node] + ): + if ( + node.op == "placeholder" + or (node not in partition) + or (node in removed_nodes) + ): + return True + if node in transparent_output_nodes: + return transparent_output_nodes[node] + if is_non_compute_node(node): + for output_n in node.users: + if not is_transparent_output_node( + output_n, partition, removed_nodes + ): + transparent_output_nodes[node] = False + return False + transparent_output_nodes[node] = True + return True + transparent_output_nodes[node] = False + return False + + for partition in partitions: + # Note it's ok to use `set` here, since we are only query if a node + # has been removed. We are NEVER going to iterate on nodes inside + # the set. + remove_node: set[Node] = set() + for node in partition.nodes: + if is_non_compute_node(node) and ( + is_transparent_input_node(node, set(partition.nodes), remove_node) + or is_transparent_output_node( + node, set(partition.nodes), remove_node + ) + ): + remove_node.add(node) + + if len(remove_node) != 0: + for node in remove_node: + partition.nodes.pop(node, None) + + def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule: + partitions = self.propose_partitions() + fused_gm = self.fuse_partitions(partitions, prefix=prefix) + return fused_gm diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/pass_base.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/pass_base.py new file mode 100644 index 0000000000000000000000000000000000000000..957b8145f995dedb7d40f7d63ba555d40173a53d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/pass_base.py @@ -0,0 +1,78 @@ +# mypy: allow-untyped-defs +import abc +from collections import namedtuple +from typing import Optional + +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule + + +__all__ = ["PassResult", "PassBase"] + + +@compatibility(is_backward_compatible=False) +class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): + """ + Result of a pass: + graph_module: The modified graph module + modified: A flag for if the pass has modified the graph module + """ + + __slots__ = () + + def __new__(cls, graph_module, modified): + return super().__new__(cls, graph_module, modified) + + +@compatibility(is_backward_compatible=False) +class PassBase(abc.ABC): + """ + Base interface for implementing passes. + + It is required to implement the `call` function so that we can directly + pass instances of the Pass directly to the PassManager and call them as a + function. + + We can directly pass an instance of a class implementing this interface into + the PassManager's `passes` attribute. + """ + + def __call__(self, graph_module: GraphModule) -> Optional[PassResult]: + """ + Runs the precondition check, the pass itself, and the postcondition check. + """ + + self.requires(graph_module) + res = self.call(graph_module) + self.ensures(graph_module) + return res + + @abc.abstractmethod + def call(self, graph_module: GraphModule) -> Optional[PassResult]: + """ + The pass that is run through the given graph module. To implement a + pass, it is required to implement this function. + + Args: + graph_module: The graph module we will run a pass on + """ + + def requires(self, graph_module: GraphModule) -> None: # noqa: B027 + """ + This function will be called before the pass is run and will check that + the given graph module contains the preconditions needed to run the + pass. It is not required to implement this function. + + Args: + graph_module: The graph module we will run checks on + """ + + def ensures(self, graph_module: GraphModule) -> None: # noqa: B027 + """ + This function will be called after the pass is run and will check that + the given graph module contains the postconditions needed to run the + pass. It is not required to implement this function. + + Args: + graph_module: The graph module we will run checks on + """ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/pass_manager.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..68753d9351f103003c9a8fcac402900ad63d1658 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/infra/pass_manager.py @@ -0,0 +1,309 @@ +# mypy: allow-untyped-defs +import inspect +import logging +from functools import wraps +from queue import Queue +from typing import Callable + +import torch.nn as nn +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule +from torch.fx.passes.infra.pass_base import PassResult + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +__all__ = ["pass_result_wrapper", "this_before_that_pass_constraint", "PassManager"] + + +@compatibility(is_backward_compatible=False) +def pass_result_wrapper(fn: Callable) -> Callable: + """ + Wrapper for passes which currently do not return a PassResult. + This wrapper makes them return a PassResult containing the modified object + and True for the "modified" flag. + + Args: + fn (Callable[Module, Any]) + + Returns: + wrapped_fn (Callable[Module, PassResult]) + """ + if fn is None: + return None + + @wraps(fn) + def wrapped_fn(gm): + res = fn(gm) + if res is None: + return PassResult(gm, True) + if isinstance(res, PassResult): + return res + elif isinstance(res, nn.Module): + return PassResult(res, True) + + if not inspect.isfunction(fn): + wrapped_fn.__name__ = type(fn).__name__ + + return wrapped_fn + + +def _validate_pass_schedule_constraint( + constraint: Callable[[Callable, Callable], bool], passes: list[Callable] +) -> None: + for i, a in enumerate(passes): + for j, b in enumerate(passes[i + 1 :]): + if constraint(a, b): + continue + raise RuntimeError( + f"pass schedule constraint violated. Expected {a} before {b}" + f" but found {a} at index {i} and {b} at index{j} in pass" + f" list." + ) + + +def _topological_sort_passes( + passes: list[Callable], constraints: list[Callable] +) -> list[Callable]: + """ + Args + passes: Passes that we are ordering + constraints: Constraints applied on these passes + + Returns + A sorted list of callables and a boolean of if a circular dependency + existed + """ + if len(constraints) == 0: + return passes + + # Contruct a graph mapping nodes to a list of their users + graph: dict[Callable, list[Callable]] = {p: [] for p in passes} + indegree_map: dict[Callable, int] = dict.fromkeys(passes, 0) + candidates: Queue = Queue() + for a in passes: + for b in passes: + if a == b: + continue + + for constraint in constraints: + if not constraint(a, b): + graph[b].append(a) + indegree_map[a] += 1 + + if indegree_map[a] == 0: + candidates.put(a) + + visited: dict[Callable, bool] = dict.fromkeys(passes, False) + sorted_passes: list[Callable] = [] + + while not candidates.empty(): + p = candidates.get() + sorted_passes.append(p) + visited[p] = True + + for n in graph[p]: + if not visited[n]: + indegree_map[n] -= 1 + if indegree_map[n] == 0: + candidates.put(n) + + # Check if there are unvisited nodes (aka cycles in the graph) + cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys())) + if len(cycle_passes) != 0: + error = ( + f"Circular dependency detected within the following passes: {cycle_passes}" + ) + raise RuntimeError(error) + + return sorted_passes + + +@compatibility(is_backward_compatible=False) +def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable: + """ + Defines a partial order ('depends on' function) where `this` must occur + before `that`. + + For example, the following pass list and constraint list would be invalid. + ``` + passes = [pass_b, pass_a] + + constraints = [this_before_that_pass_constraint(pass_a, pass_b)] + ``` + + Args: + this (Callable): pass which should occur first + that (Callable): pass which should occur later + + Returns: + depends_on (Callable[[Object, Object], bool] + """ + + def depends_on(a: Callable, b: Callable): + return a != that or b != this + + return depends_on + + +@compatibility(is_backward_compatible=False) +class PassManager: + """ + Construct a PassManager. + + Collects passes and constraints. This defines the pass schedule, manages + pass constraints and pass execution. + + Args: + passes (Optional[List[Callable]]): List of passes. A pass is a + callable which modifies an object and returns a PassResult + constraint (Optional[List[Callable]]): List of constraints. A + constraint is a callable which takes two passes (A, B) and returns + True if A depends on B and False otherwise. See implementation of + `this_before_that_pass_constraint` for example. + steps (int): Max number of times we run the passes (default = 1). + run_checks_after_each_pass (bool): Whether to run checks and linting + after each pass + suppress_check_failures (bool): Whether to raise errors when running + checks + """ + + passes: list[Callable[[nn.Module], PassResult]] + constraints: list[Callable[[Callable, Callable], bool]] + _validated: bool = False + steps: int = 1 + + def __init__( + self, + passes=None, + constraints=None, + steps=None, + run_checks_after_each_pass: bool = False, + suppress_check_failures: bool = False, + ): + self.passes = passes or [] + self.constraints = constraints or [] + if steps: + self.steps = steps + + self.run_checks_after_each_pass = run_checks_after_each_pass + self.suppress_check_failures = suppress_check_failures + + def add_pass(self, _pass: Callable): + """ + Adds a pass into the current list of passes. + """ + self.passes.append(_pass) + self._validated = False + + def add_constraint(self, constraint: Callable): + """ + Adds a constraint into the current list of constraints. + """ + self.constraints.append(constraint) + self._validated = False + + def validate_constraints(self): + """ + Validates that current pass schedule defined by `self.passes` is valid + according to all constraints in `self.constraints` + """ + if self._validated: + return + for constraint in self.constraints: + _validate_pass_schedule_constraint(constraint, self.passes) + self._validated = True + + def solve_constraints(self): + """ + Finds a valid traversal order based on the given constraints and orders + the passes based on this order. + + If a circular dependency exists between the constraints and steps = 1, + then we will raise an error because if steps != 1 this means that we + will re-run the passes, allowing for circular dependencies. + """ + self.passes = _topological_sort_passes(self.passes, self.constraints) + self._validated = True + + def add_checks(self, check: Callable) -> None: + """ + Adds a function which takes runs various checks on a given graph module. + This function is run before and after each pass if the + `run_checks_after_each_pass` flag is enabled. + """ + sig = inspect.signature(check) + + if len(list(sig.parameters.values())) != 1: + raise TypeError( + "PassManager check function should only take in one variable, a module" + ) + + setattr(self, "check", check) # noqa: B010 + + def check(self, module: nn.Module) -> None: + pass + + def __call__(self, module: nn.Module) -> PassResult: + """ + Runs a list of passes in the order based on `self.passes` on the given + graph module. Each time a pass is run, checks and linting will be run on + the graph module if `run_checks_after_each_pass` is set. + + If the module is a graph module, we will run the list of passes until + the graph stops changing, or until `steps` number of times. + """ + # Order the passes based on the constraints + if not self._validated: + self.solve_constraints() + + # Check graph invariants + self.check(module) + + # Run the set of passes `steps` number of times or until the graph stops + # changing + overall_modified = False + for _ in range(self.steps): + modified = False + + # Run the set of passes on the graph module + for i, fn in enumerate(self.passes): + fn_name = fn.__name__ if inspect.isfunction(fn) else type(fn).__name__ + logger.debug("Running pass '%s'", fn_name) + + try: + res = fn(module) + + if not isinstance(res, PassResult) and not hasattr( + res, "graph_module" + ): + raise TypeError( + f"The result of the pass {fn_name} should be type PassResult." + + "Please wrap it with pass_result_wrapper()" + ) + module = res.graph_module + modified = modified or res.modified + + if isinstance(module, GraphModule): + logger.debug("Graph after pass '%s': %s", fn_name, module.graph) + module.recompile() + + # Check graph invariants + if self.run_checks_after_each_pass: + self.check(module) + + except Exception as e: + prev_pass_names = [ + p.__name__ if inspect.isfunction(p) else type(p).__name__ + for p in self.passes[:i] + ] + msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}" + raise Exception(msg) from e # noqa: TRY002 + + # If the graph no longer changes, then we can stop running these passes + overall_modified = overall_modified or modified + if not modified: + break + + return PassResult(module, overall_modified) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/net_min_base.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/net_min_base.py new file mode 100644 index 0000000000000000000000000000000000000000..548d2786feea77fc13afd1f37009af25781e1bbc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/net_min_base.py @@ -0,0 +1,978 @@ +# mypy: allow-untyped-defs +import logging +from dataclasses import dataclass +from typing import Any, Callable, cast, Optional + +import torch +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.node import map_arg + +from .shape_prop import ShapeProp +from .split_utils import split_by_tags +from .tools_common import ( + CALLABLE_NODE_OPS, + FxNetAccFusionsFinder, + Names, + NodeList, + NodeSet, + TensorOrTensors, + Tensors, +) + + +__all__ = [ + "FxNetMinimizerBadModuleError", + "FxNetMinimizerRunFuncError", + "FxNetMinimizerResultMismatchError", +] + +_LOGGER = logging.getLogger(__name__) + + +@compatibility(is_backward_compatible=False) +class FxNetMinimizerBadModuleError(Exception): + """ + Raised if failed to split out a minimize module + """ + + +@compatibility(is_backward_compatible=False) +class FxNetMinimizerRunFuncError(Exception): + """ + Raised if error occurs during run_a or run_b functions + """ + + +@compatibility(is_backward_compatible=False) +class FxNetMinimizerResultMismatchError(Exception): + """ + Raised if comparing function thinks the results are mismatching. + """ + + +@dataclass +class _MinimizerSettingBase: + """ + Args: + `accumulate_error`: Instead of using a's input for both converted module to verify + , use the previous outputs of each converted module as input to accumulate the + errors. + + `traverse_method`: "sequential" or "binary" or "accumulate" + Determine the way of traverse the nodes in FX module. + + `find_all`: Minimizer will go through the entire model and return all problematic nodes. + + `return_intermediate`: If true, when using `run_nodes()` function to run the + model, intermediate results of all the ops will be returned as output. + + `all_outputs`: If true, when using `_run_and_compare()` function, + all the output nodes in the subgraph will be used for comparison. + """ + + accumulate_error: bool = False + traverse_method: str = "sequential" + find_all: bool = False + return_intermediate: bool = False + all_outputs: bool = False + + def __str__(self): + settings_str = "FX Minimizer Settings:\n" + + for k, v in vars(self).items(): + settings_str += f"\t{k}: {v}\n" + + return settings_str + + +class _MinimizerBase: + """ + This class is used to automatically find problematic nodes in a model. It takes a FX + graphmodule and generate some submodules while traverse the graph. Then two functions + `run_a` and `run_b` will be used to run the same submodule and a function `compare_fn` + will be used to compare the results. + + Currently we provides two ways to traverse the graph and generate submodules. + 1. Sequential traversal: this will traverse the graph node by node and generate + one submodule with one sigle node. + 2. Binary searching: this will do a binary search style traversal on the graph. + + For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + sample_input: Tensors, + compare_fn: Callable[ + [TensorOrTensors, TensorOrTensors, Names], tuple[float, bool] + ], + settings: _MinimizerSettingBase, + module_exporter: Optional[ + Callable[[Tensors, torch.fx.GraphModule, str], None] + ] = None, + exclusion_fn: Optional[Callable[[NodeList, int, int], None]] = None, + ): + assert isinstance(module, torch.fx.GraphModule) + + self.module = module + self.sample_input = sample_input + self.compare_fn = compare_fn + self.module_exporter = module_exporter + self.settings = settings + self.exclusion_fn = exclusion_fn + + # Stores outputs of run_a function + self.a_outputs: dict[str, Any] = {} + + # Stores outputs of run_b function + self.b_outputs: dict[str, Any] = {} + + # Stores the results of compare_fn + self.results: dict[Any, Any] = {} + + # Stores the report for the runs + self.reports: list[list[str]] = [] + + # Current iteration + self.iteration: int = 0 + + callable_nodes = { + node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS + } + self.run_shape_prop() + self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)() + + # Check if number of input in sample_input matches the number of placeholders + placeholders = [ + node.name for node in self.module.graph.nodes if node.op == "placeholder" + ] + assert len(placeholders) == len(self.sample_input) + + # Store sample_input + for i, name in enumerate(placeholders): + self.a_outputs[name] = sample_input[i] + self.b_outputs[name] = sample_input[i] + + def run_shape_prop(self) -> None: + """ + Helper function to run shape propagation on module. Can be overridden by + subclasses for custom shape propagation logic. + """ + ShapeProp(self.module).propagate(*self.sample_input) + + def run_a( + self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1 + ) -> TensorOrTensors: + """ + Run `mod` with `inputs` and generate output. The output will be compared with + output of run_b(). + """ + raise RuntimeError("run_a() is not implemented.") + + def run_b( + self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1 + ) -> TensorOrTensors: + """ + Run `mod` with `inputs` and generate output. The output will be compared with + output of run_a(). + """ + raise RuntimeError("run_b() is not implemented.") + + def _store_outputs( + self, + a_result: TensorOrTensors, + b_result: TensorOrTensors, + submodule: torch.fx.GraphModule, + ): + """ + Store the outputs of self.run_a() and self.run_b() into self.a_outputs and + self.b_outputs, so that we can use them when execute preceding nodes that + use those outputs as inputs. + + Args: + a_result: Output of self.run_a(). Could be a tensor or tensors. + b_result: Output of self.run_b(). Could be a tensor or tensors. + submodule: The module that generates a_result and b_result. + """ + output_node = next( + node for node in submodule.graph.nodes if node.op == "output" + ) + + # Only one output + if isinstance(output_node.args[0], torch.fx.Node): + self.a_outputs[output_node.args[0].name] = a_result + self.b_outputs[output_node.args[0].name] = b_result + # Multiple outputs + else: + for i, arg in enumerate(output_node.args[0]): + self.a_outputs[arg.name] = a_result[i] + self.b_outputs[arg.name] = b_result[i] + + def _get_submod_inputs( + self, main_module: torch.fx.GraphModule, submod_path: str + ) -> tuple[Tensors, Tensors]: + """ + Try get submodule inputs from stored outputs. If not found then use + torch_glow.get_submod_inputs to get the inputs. + + If accumulate_error is False, use a_input for run_a() and run_b() + otherwise use a_input for run_a and b_input for run_b. + + Args: + main_module: Top-levlel fx module. + submod_path: Path to the submodule we want to run and compare results. + + Returns: + a_input: List of tensor(s) that will be used by run_a() as submodule inputs. + b_input: List of tensor(s) that will be used by run_b() as submodule inputs. + """ + a_input = [] + b_input = [] + submodule = getattr(main_module, submod_path) + placeholders = [ + node.name for node in submodule.graph.nodes if node.op == "placeholder" + ] + + # If all placeholder can be found in stored outputs, use stored + # outputs as inputs. Otherwise, use `torch_glow.get_submod_inputs` + # to get the inputs. + if set(placeholders) <= self.a_outputs.keys(): + for name in placeholders: + a_input.append(self.a_outputs[name]) + b_input.append(self.b_outputs[name]) + else: + if self.settings.accumulate_error: + print(f"Can't find previous stored outputs named {placeholders}!") + + def get_inputs(self: torch.nn.Module, inputs: Any): + nonlocal a_input + a_input = inputs + + # Use forward hook to get the inputs to the submodule + handle = submodule.register_forward_pre_hook(get_inputs) + main_module(*self.sample_input) + handle.remove() + + b_input = a_input + + if not self.settings.accumulate_error: + return a_input, a_input + + return a_input, b_input + + def _tag_nodes(self, selected_nodes: NodeSet): + """ + Tag selected nodes with tag "minimize". Nodes with the same tags will + be split to the same submodule afterwards. + + Args: + selected_nodes: Nodes that we want to minimize. We will tag those nodes + with "minimize", all preceding nodes with "main_0" and all following + nodes with "main_1". + """ + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + if node in selected_nodes: + node.tag = "minimize" + elif any( + n.tag in {"minimize", "main_1"} + for n in node.all_input_nodes + if n.op in CALLABLE_NODE_OPS + ): + node.tag = "main_1" + else: + node.tag = "main_0" + + def _build_submodule(self, nodes: NodeSet) -> tuple[torch.fx.GraphModule, str]: + """ + Split self.module so that one submodule consists of `nodes` and only `nodes`. + + Args: + nodes: Nodes that we want to include in the minimize submodule. + + Returns: + split_module (torch.fx.GraphModule): the module after split. + submodule_name (str): the name of the submodule that consists of `nodes`. + """ + # Color provided nodes + self._tag_nodes(nodes) + + # Split module based on coloring + split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"]) + + # Find submodule containing colored nodes + submodule_name: str = "" + for child_name, _ in split_module.named_children(): # type: ignore[union-attr] + # Skip submodules we're not interested in at the moment + if "minimize" not in child_name: + continue + + if submodule_name == "": + submodule_name = child_name + else: + raise FxNetMinimizerBadModuleError( + f"Expected only one minimize submodule with nodes {nodes}" + ) + + if submodule_name == "": + raise FxNetMinimizerBadModuleError( + f"Minimize submodule was not found with nodes {nodes}" + ) + + return split_module, submodule_name # type: ignore[return-value] + + def _run_and_compare( + self, + split_module: torch.fx.GraphModule, + submod_name: str, + output_names: Names, + report_idx: int = -1, + ): + """ + Run the submodule in `split_module` that has name `submod_name` + using `self.run_a` and `self.run_b` and compare their results. + + Args: + split_module: Main module that contains the minimize submodule. + submod_name: Name of the minimize submodule. + output_names: Names of the node we want to output. If None, we + will use the original output. + """ + submodule = getattr(split_module, submod_name) + a_input, b_input = self._get_submod_inputs(split_module, submod_name) + + if len(self.reports) == 0: + self.reports.append([]) + self.iteration = 1 + + report = self.reports[report_idx if report_idx >= 0 else self.iteration - 1] + report.append("Run and compare ...") + + if output_names and not self.settings.all_outputs: + output_nodes: NodeList = [] + for node in submodule.graph.nodes: + if node.op == "output": + submodule.graph.erase_node(node) + + if node.name in output_names: + output_nodes.append(node) + + submodule.graph.output( + output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes) + ) + submodule.graph.lint() + submodule.recompile() + + # Use name of args in output node as key to store comparison result + for node in submodule.graph.nodes: + if node.op == "output": + result_key = map_arg(node.args, lambda x: x.name) + + try: + a_result = self.run_a(submodule, a_input, report_idx) + b_result = self.run_b(submodule, b_input, report_idx) + self._store_outputs(a_result, b_result, submodule) + except Exception as e: + report.append(f"Exception raised when running {submod_name}: {e}") + raise FxNetMinimizerRunFuncError( # noqa: B904 + f"Exception raised when running {submod_name}: {e}" + ) + + # Compare results + names: Names = output_names + if output_names is None: + names = [str(v) for v in result_key] # type: ignore[possibly-undefined] + + numeric_result, bool_result = self.compare_fn(a_result, b_result, names) + + self.results[result_key] = numeric_result # type: ignore[possibly-undefined] + report.append(f"Numerical accuracy = {numeric_result}") + if not bool_result: + report.append(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined] + if self.module_exporter: + if isinstance(result_key, tuple): # type: ignore[possibly-undefined] + result_key = result_key[-1] + # If the result is still a tuple (happens in non-sequential mode), + # we only use the first element as name. + if isinstance(result_key, tuple): # type: ignore[possibly-undefined] + result_key = str(result_key[0]) + # pyre-ignore[29]: not a function + self.module_exporter( + a_input, + submodule, + result_key + "_cpu", + ) + # pyre-ignore[29]: not a function + self.module_exporter( + b_input, + submodule, + result_key + "_acc", + ) + raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined] + + def _binary_search_impl( + self, all_nodes: NodeList, start_idx: int, end_idx: int + ) -> NodeSet: + """ + Recursive binary search implementation. + """ + culprits: NodeSet = set() + nodes: NodeList = all_nodes[start_idx:end_idx] + + report: list[str] = [] + if self.exclusion_fn is not None: + self.exclusion_fn(nodes, start_idx, end_idx) + if len(nodes) == 0: + report = ["All nodes are excluded by user"] + self.reports.append(report) + return culprits + + first_node_name = nodes[0].name + output_node_name = nodes[-1].name + self.iteration += 1 + self.reports.append(report) + report.append(f"Binary search iteration {self.iteration}") + report.append( + f"From node index {start_idx}:{first_node_name} to {end_idx - 1}:{output_node_name}. " + f"Size of the interested node list is {len(nodes)}" + ) + cur_nodes: NodeSet = set(nodes) + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, [output_node_name]) + + except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError): + if len(nodes) == 1: + report.append( + f"This is the last node in the sub-module. " + f"Search in the current branch is successful with culprit = {cur_nodes}." + ) + self.print_report(report) + return cur_nodes + + report.append( + "Proceed to split and lower the halves of the current " + "sub-module individually." + ) + self.print_report(report) + + mid = len(nodes) // 2 + culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid) + + if len(culprits) != 0 and not self.settings.find_all: + return culprits + + culprits = self._binary_search_impl(all_nodes, start_idx + mid, end_idx) + + if len(culprits) == 0: + report.append( + f"Further split and lowering found no errors. " + f"Unable to minimize the submodule with list of nodes: {nodes}" + ) + self.print_report(report) + + return culprits + else: + report.append("No discrepancy found.") + self.print_report(report) + return set() + + def _binary_traverse(self, nodes: NodeList) -> NodeSet: + """ + Binary search on `nodes` for culprit. + """ + return self._binary_search_impl(nodes, 0, len(nodes)) + + def _sequential_traverse(self, nodes: NodeList) -> NodeSet: + """ + Traverse `nodes` one by one and determine if any of them is a culprit. + """ + culprits: NodeSet = set() + + for node in nodes: + report: list[str] = [] + self.reports.append(report) + self.iteration += 1 + report.append(f"Sequential traverse iteration {self.iteration}.") + report.append(f"Visit node: {node.name}") + + _LOGGER.info("Visit node: %s", node.name) + node_list: NodeList = [node] + if self.exclusion_fn is not None: + self.exclusion_fn(node_list, -1, -1) + if len(node_list) == 0: + report.append(f"User exclusion : {node.name}") + self.print_report(report) + if not self.settings.find_all: + return culprits + else: + continue + + cur_nodes: NodeSet = {node} + + if node in self.fusions: + cur_nodes = self.fusions[node] + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, [node.name]) + self.print_report(report) + except FxNetMinimizerResultMismatchError: + culprits.add(node) + report.append(f"Found culprit from numeric error: {node}") + self.print_report(report) + if not self.settings.find_all: + return culprits + except FxNetMinimizerRunFuncError: + culprits.update(cur_nodes) + report.append(f"Found culprit from run error: {node}") + self.print_report(report) + if not self.settings.find_all: + return culprits + + return culprits + + def _block_traverse_impl( + self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool + ) -> Optional[int]: + """ + Recursive block search implementation. + find_last_node: If True, search for the last node which result in numerics difference + if False: find first node in sorted node list + """ + report: list[str] = [] + + mid = (start_idx + end_idx) // 2 + cur_nodes_list: NodeList = nodes[: mid + 1] if find_last_node else nodes[mid:] + + if self.exclusion_fn: + self.exclusion_fn(cur_nodes_list, -1, -1) + + cur_nodes = set(cur_nodes_list) + + first_node_name = cur_nodes_list[0].name + last_node_name = cur_nodes_list[-1].name + target_node_name = last_node_name if find_last_node else first_node_name + + self.iteration += 1 + self.reports.append(report) + report.extend( + [ + "=" * 30, + f"Block search iteration {self.iteration}", + ] + ) + report.extend( + [ + f"Search for {'last' if find_last_node else 'first'} node in culprits", + f"From node index {start_idx}:{nodes[start_idx].name} to {end_idx}:{nodes[end_idx].name}. ", + f"Subgraph constructed by {first_node_name} to {last_node_name}", + f"Targeting node: {target_node_name}", + f"Size of the interested node list is {end_idx - start_idx + 1}", + ] + ) + report_idx = len(self.reports) - 1 + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare( + split_module, submod_name, [last_node_name], report_idx + ) + except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): + report.append( + f"Culprits found from node {first_node_name} to {last_node_name}." + ) + + if start_idx == mid == end_idx: + report.extend( + [ + "This is the last node in the sub-module. ", + "Search in the current branch is successful with node :", + f"{start_idx}, node name: {nodes[start_idx].name}.", + ] + ) + self.print_report(report) + return start_idx + + report.append( + "Proceed to split and lower the halves of the current " + "sub-module individually." + ) + self.print_report(report) + + if find_last_node: + return self._block_traverse_impl(nodes, start_idx, mid, find_last_node) + else: + return self._block_traverse_impl( + nodes, mid + 1, end_idx, find_last_node + ) + else: + report.append( + f"Culprits not found from node start to {mid}:{nodes[mid].name}." + ) + + if start_idx == mid == end_idx: + # We did not find anything if the pointers have not moved + if (start_idx == 0 and not find_last_node) or ( + start_idx == len(nodes) - 1 and find_last_node + ): + report.append( + f"At {'last' if find_last_node else 'first'} node, no culprits found." + ) + self.print_report(report) + return None + + # Otherwise, we have converged on the border between discrepancy and valid + return start_idx + (1 if find_last_node else -1) + + report.append( + "Proceed to split and lower the halves of the current " + "sub-module individually." + ) + self.print_report(report) + + if find_last_node: + return self._block_traverse_impl( + nodes, mid + 1, end_idx, find_last_node + ) + else: + return self._block_traverse_impl(nodes, start_idx, mid, find_last_node) + + def _block_traverse( + self, nodes: NodeList, find_last_node: Optional[bool] + ) -> NodeSet: + """ + Traverse topologically sorted node list + Find minimium block (start_idx, end_idx) which contains the culprit + 1st pass: search for end_idx by finding the last node in culprit block + where Numerical accuracy (0, end_idx) > threshold + 2nd pass: search for start_idx by finding the first node in culprit block + where Numerical accuracy (start_idx, end_idx) < threshold + Form minimum block by (start_idx - 1, end_idx) + """ + culprits: NodeSet = set() + first_node_name = nodes[0].name + last_node_name = nodes[-1].name + last_node_report = [f"Block search from {first_node_name} to {last_node_name}"] + last_node_report.append("*" * 50) + self.reports.append(last_node_report) + + start_idx = 0 + end_idx = len(nodes) - 1 + + final_start_idx: Optional[int] = start_idx + final_end_idx: Optional[int] = end_idx + + run_both = True if find_last_node is None else False + + # step 1: find (0, end_idx) of culprit block + if run_both or find_last_node: + last_node_report.append("Start searching for last node in culprit") + self.print_report(last_node_report) + final_end_idx = self._block_traverse_impl(nodes, start_idx, end_idx, True) + + if final_end_idx is None: + last_node_report.append("No culprits found") + self.print_report(last_node_report) + return culprits + + last_node_report.extend( + [ + "Finish Pass 1", + f"Find end_idx = {final_end_idx}:{nodes[final_end_idx].name}", + ] + ) + self.print_report(last_node_report) + + # step 2: reduce culprit block to (start_idx, end_idx) + if run_both or not find_last_node: + first_node_report = ["Start searching for first node in culprit"] + self.print_report(first_node_report) + final_start_idx = self._block_traverse_impl( + nodes[0 : end_idx + 1], start_idx, final_end_idx or end_idx, False + ) + + if final_start_idx is None: + last_node_report.append("No culprits found") + self.print_report(last_node_report) + return culprits + + first_node_report.append("*" * 50) + self.reports.append(first_node_report) + first_node_report.extend( + [ + "Finish Pass 2", + f"Find start_idx = {final_start_idx}:{nodes[final_start_idx].name}", + ] + ) + self.print_report(first_node_report) + + # step 3: form module with minimum culprits. These indexes are guaranteed to exist + range_start, range_end = cast(int, final_start_idx), cast(int, final_end_idx) + culprits.update(nodes[range_start : range_end + 1]) + result_report = [ + f"Finish searching, found minimum block ({nodes[range_start]},{nodes[range_end]})" + ] + self.reports.append(result_report) + self.print_report(result_report) + return culprits + + def _defined_traverse(self, nodes: NodeList) -> NodeSet: + """ + run user defined `nodes` and determine if it is a culprit. + """ + culprits: NodeSet = set() + if self.exclusion_fn is not None: + self.exclusion_fn(nodes, -1, -1) + if len(nodes) == 0: + report = ["All nodes are excluded by user"] + self.reports.append(report) + return culprits + + first_node_name = nodes[0].name + output_node_name = nodes[-1].name + report = [f"Defined graph from {first_node_name} to {output_node_name}"] + cur_nodes: NodeSet = set(nodes) + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, [output_node_name]) + self.print_report(report) + except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): + report.append(f"Found culprit {cur_nodes}") + self.print_report(report) + return culprits + + return culprits + + def _accumulate_traverse(self, nodes: NodeList) -> NodeSet: + culprits: NodeSet = set() + nodes_to_run: NodeSet = set() + + # find_all is not supported for accumulate traversal because all the + # ops run on NNPI. So we return after the first op that raises error. + if self.settings.find_all: + print("'Find All' mode is not supported in accumulate traversal.") + return culprits + + for node in nodes: + report: list[str] = [] + self.reports.append(report) + self.iteration += 1 + report.append(f"Accumulate traverse iteration {self.iteration}.") + + nodes_to_run.add(node) + + node_name = node.name + if node_name is not None and isinstance(node_name, tuple): + node_name = node_name[0] + assert node_name is not None and isinstance(node_name, str), ( + f"minimize: node_name: {node_name}" + ) + + report.append(f"Add node: {node_name}") + + try: + split_module, submod_name = self._build_submodule(nodes_to_run) + self._run_and_compare(split_module, submod_name, [node_name]) + self.print_report(report) + except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): + culprits.add(node) + report.append(f"Found culprit {node}") + self.print_report(report) + return culprits + + return culprits + + def _skip_traverse_impl( + self, all_nodes: NodeList, start_idx: int, end_idx: int + ) -> NodeSet: + """ + Skip certain nodes in graph based on settings + """ + culprits: NodeSet = set() + nodes: NodeList = all_nodes[start_idx:end_idx] + cur_nodes: NodeSet = set(nodes) + if self.exclusion_fn is not None: + self.exclusion_fn(nodes, start_idx, end_idx) + cur_nodes = set(nodes) + else: + for node in nodes: + if node in self.fusions: + cur_nodes.update(self.fusions[node]) + report: list[str] = [] + self.reports.append(report) + self.iteration += 1 + report.append(f" Nodes block {self.iteration}.") + report.append( + f"From node index {start_idx} to {end_idx - 1}. " + f"Size of the interested node list is {len(nodes)}" + ) + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, []) + except FxNetMinimizerResultMismatchError: + culprits.update(cur_nodes) + report.append(f"Found culprit from numeric error: {cur_nodes}") + self.print_report(report) + return culprits + except FxNetMinimizerRunFuncError: + culprits.update(cur_nodes) + report.append(f"Found culprit from run error: {cur_nodes}") + self.print_report(report) + return culprits + else: + report.append("No discrepancy found.") + self.print_report(report) + return set() + + def _skip_traverse(self, all_nodes: NodeList, skip_nodes: list) -> NodeSet: + """ + Skip certain nodes in graph based on settings + """ + start_idx = 0 + num_nodes = len(all_nodes) + idx = 0 + culprits = set() + while idx < num_nodes: + node = all_nodes[idx] + if node.name in skip_nodes: # skip the node + if idx > start_idx: + culprits = self._skip_traverse_impl(all_nodes, start_idx, idx) + start_idx = idx + 1 + elif idx == num_nodes - 1 and start_idx <= idx: # last node + culprits = self._skip_traverse_impl(all_nodes, start_idx, idx + 1) + idx += 1 + + return culprits + + def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList: + """ + Collect nodes in the model that between nodes with name of `start` and `end`. + These two nodes are also included. + """ + nodes: NodeList = [] + add_node = start is None + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + if node.name == start: + add_node = True + + if add_node: + nodes.append(node) + + if node.name == end: + break + + return nodes + + def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None): + """ + Run part of the model from `start` node to `end` node. If `start` is None + then we start from the beginning of the model. If `end` is None then we + stop at the end of the model. + + Args: + start: The name of the node which is the first node of the submodule + we want to run. If set to None, then we'll start with the first + node of the model. + end: The name of the node which is the last node of the submodule we + want to run. If set to None, we'll end with the last node of the + model. + """ + nodes = self._collect_nodes(start, end) + cur_nodes = set(nodes) + + for node in nodes: + if node in self.fusions: + cur_nodes.update(self.fusions[node]) + + output_names = [] + if self.settings.return_intermediate: + output_names = [node.name for node in nodes] + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, output_names) + except ( + FxNetMinimizerRunFuncError, + FxNetMinimizerResultMismatchError, + ) as e: + print(e) + + def print_report(self, report: list[str]): + for i in range(len(report)): + if i > 0: + print(" . " + report[i]) + else: + print(report[i]) + + def print_reports(self): + for report in self.reports: + self.print_report(report) + + def minimize( + self, + start: Optional[str] = None, + end: Optional[str] = None, + skip_nodes: Optional[list] = None, + find_last_node: Optional[bool] = None, + ) -> NodeSet: + """ + Minimizing the model from node with name `start` to node with name `end` base + on self.settings. Find culprits that causes FxNetMinimizerRunFuncError or + FxNetMinimizerResultMismatchError errors. + + Args: + start: The name of the node where we want to start minimizing. If set + to None, then we'll start with the first node of the model. + end: The name of the node where we want to terminate minimizing. If + set to None, we'll end with the last node of the model. + skip_nodes: The names of nodes where we want to skip during minimizing. + It'll create subgraphs without these skip nodes under the hood. + Only applicable in mode "skip". + find_last_node: True if only last_node of a culprits is needed in mode "block". + False if only the first_node of a culprits is needed. + Only applicable in mode "block". + + Returns: + nodes: A list of nodes that causes FxNetMinimizerRunFuncError or + FxNetMinimizerResultMismatchError errors during minimizing. + """ + + print(self.settings) + print(self.module.graph) + + nodes = self._collect_nodes(start, end) + + if self.settings.traverse_method == "sequential": + return self._sequential_traverse(nodes) + + if self.settings.traverse_method == "binary": + return self._binary_traverse(nodes) + + if self.settings.traverse_method == "accumulate": + return self._accumulate_traverse(nodes) + + if self.settings.traverse_method == "skip": + if skip_nodes is None: + raise RuntimeError( + "'skip_nodes' can't be None when 'traverse_method' is 'skip'." + ) + return self._skip_traverse(nodes, skip_nodes) + + if self.settings.traverse_method == "defined": + return self._defined_traverse(nodes) + + if self.settings.traverse_method == "block": + return self._block_traverse(nodes, find_last_node) + + raise RuntimeError(f"Unknown traverse method {self.settings.traverse_method}!") diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/operator_support.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/operator_support.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb14d312b60b0209195706488dd48a359c40b3f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/operator_support.py @@ -0,0 +1,229 @@ +# mypy: allow-untyped-defs +import abc +import typing as t + +import torch +import torch.fx +from torch.fx._compatibility import compatibility + +from .shape_prop import TensorMetadata +from .tools_common import CALLABLE_NODE_OPS, get_node_target + + +__all__ = [ + "OperatorSupportBase", + "OperatorSupport", + "create_op_support", + "chain", + "OpSupports", + "any_chain", +] + +# fx.Node.target typename, as returned by `get_node_target()` +TargetTypeName = str + +# Arguments' dtypes for a given node, see `OperatorSupport` +SupportedArgumentDTypes = t.Optional[ + tuple[ + t.Sequence[t.Sequence[torch.dtype]], + dict[str, t.Sequence[torch.dtype]], + ] +] + +SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes] + + +@compatibility(is_backward_compatible=False) +class OperatorSupportBase(abc.ABC): + """Interface for determining if a fx.Node is supported by a backend""" + + @abc.abstractmethod + def is_node_supported( + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + raise NotImplementedError + + +@compatibility(is_backward_compatible=False) +class OperatorSupport(OperatorSupportBase): + """ + `_support_dict` maps node.target typename to supported inputs dtypes. + + node.target typename is retrieved using helper function `get_node_target()` + + If supported inputs dtypes is None, it means any dtype is supported, else + we should see a tuple like (([dtypes], ...), {"name":[dtypes], ...}). + + The first tuple ([dtypes], ...) indicates what dtypes are supported for + inputs in node.args and the second dict {"name": [dtypes], ...} indicates + what dtypes are supported for inputs in node.kwargs. + + For inputs in args, if we don't want to check it, we can put None there, + e.g. (None, [torch.float]) indicates that we don't care about the type of + the first input in args. And for inputs in kwargs, if not listed, will not + be checked. + """ + + _support_dict: SupportDict + + def __init__(self, support_dict: t.Optional[SupportDict] = None): + self._support_dict = support_dict or {} + + def is_node_supported( + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + """ + Args: + `submodules`: mapping from module name to the module. This can be + retrieved by calling model.named_modules(). + + `node`: a Fx node that we want to determine whether it's supported. + + Returns: + `is_supported`: whether the arg `node` is supported. + """ + if node.op not in CALLABLE_NODE_OPS: + return True + + target = get_node_target(submodules, node) + + # Target not found in _support_dict meaning that we don't support this op at all + if target not in self._support_dict: + return False + + # The rule for target is None meaning that we accept any dtype + if self._support_dict[target] is None: + return True + + args_dtypes, kwargs_dtypes = self._support_dict[target] # type: ignore[misc] + + # Check args dtypes + for i, dtypes in enumerate(args_dtypes): + if len(node.args) <= i: + break + + # None indicates we don't care about the dtype of args[i] + if dtypes is None: + continue + + # If arg is not a node then we don't check it + if not isinstance(node.args[i], torch.fx.Node): + continue + + arg_dtype = _get_arg_dtype(node.args[i]) # type: ignore[arg-type] + if arg_dtype not in dtypes: + return False + + # Check kwargs dtypes + for k, dtypes in kwargs_dtypes.items(): + if k not in node.kwargs: + continue + + # If arg is not a node then we don't check it + if not isinstance(node.kwargs[k], torch.fx.Node): + continue + + kwarg_dtype = _get_arg_dtype(node.kwargs[k]) # type: ignore[arg-type] + if kwarg_dtype not in dtypes: + return False + + return True + + +# ====================================================================== +# Functional interfaces and utils for defining basic operator support logic +# and composing them into more complex ones +# ====================================================================== + +IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], torch.fx.Node], bool] + + +@compatibility(is_backward_compatible=False) +def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase: + """Wraps a `IsNodeSupported` function into an `OperatorSupportBase` instance + + `IsNodeSupported` has the same call signature as + `OperatorSupportBase.is_node_supported` + """ + + class FunctionalOperatorSupport(OperatorSupportBase): + def is_node_supported( + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + return is_node_supported(submodules, node) + + return FunctionalOperatorSupport() + + +@compatibility(is_backward_compatible=False) +def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: + """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase` + instance by evaluating each input `OperatorSupportBase` instance, and returns False if + any of it reports False. + """ + + def _chain(submods, node) -> bool: + return all(x.is_node_supported(submods, node) for x in op_support) + + return create_op_support(_chain) + + +@compatibility(is_backward_compatible=False) +def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: + """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase` + instance by evaluating each input `OperatorSupportBase` instance, and returns True if + any of it reports True. + """ + + def _any_chain(submods, node) -> bool: + return any(x.is_node_supported(submods, node) for x in op_support) + + return create_op_support(_any_chain) + + +@compatibility(is_backward_compatible=False) +class OpSupports: + """A set of atomic `OperatorSupportBase` instances that can be combined together + to form more complex operator support logic. + """ + + @classmethod + def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase: + """Report a node as non-supported, if any of its arguments is of dtype""" + + def _decline_if_input_dtype( + submodules: t.Mapping[str, torch.nn.Module], + node: torch.fx.Node, + ) -> bool: + for arg in node.all_input_nodes: + arg_dtype = _get_arg_dtype(arg) + if arg_dtype == dtype: + return False + return True + + return create_op_support(_decline_if_input_dtype) + + @classmethod + def decline_if_node_in_names(cls, disallow_set: set[str]) -> OperatorSupportBase: + """ + If a node has a name that is in the disallow set, reported it as non-supported. + """ + + def _decline_if_node_in_names( + submodules: t.Mapping[str, torch.nn.Module], + node: torch.fx.Node, + ) -> bool: + return node.name not in disallow_set + + return create_op_support(_decline_if_node_in_names) + + +def _get_arg_dtype(arg: torch.fx.Node) -> t.Any: + assert isinstance(arg, torch.fx.Node) + tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr] + dtype = ( + tensor_meta.dtype + if isinstance(tensor_meta, TensorMetadata) + else arg.meta["type"] + ) + return dtype diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/param_fetch.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/param_fetch.py new file mode 100644 index 0000000000000000000000000000000000000000..02904b8e403e51a6cb00fae1dcdd4bbfbe2a66a6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/param_fetch.py @@ -0,0 +1,96 @@ +from typing import Any, Callable + +import torch +import torch.nn as nn +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule + + +__all__ = [ + "default_matching", + "extract_attrs_for_lowering", + "lift_lowering_attrs_to_nodes", +] + + +# Matching method matches the attribute name of current version to the attribute name of `target_version` +@compatibility(is_backward_compatible=False) +def default_matching(name: str, target_version: int) -> str: + """Default matching method""" + return name + + +# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering. +# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list. +# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module. +module_fetch_book: dict[type, tuple[int, list[str], Callable[[str, int], str]]] = { + torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching), + torch.nn.modules.conv.Conv2d: ( + 1, + [ + "weight", + "bias", + "kernel_size", + "stride", + "padding", + "dilation", + "groups", + "padding_mode", + ], + default_matching, + ), + torch.nn.modules.batchnorm.BatchNorm2d: ( + 2, + ["weight", "bias", "running_mean", "running_var", "eps"], + default_matching, + ), + torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching), + torch.nn.modules.pooling.MaxPool2d: ( + 1, + ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], + default_matching, + ), + torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching), +} + + +@compatibility(is_backward_compatible=False) +def extract_attrs_for_lowering(mod: nn.Module) -> dict[str, Any]: + """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book` + after checking module's version is compatible with the `module_fetch_book`. + """ + attrs_for_lowering: dict[str, Any] = {} + attrs_for_lowering["name"] = torch.typename(mod) + + if type(mod) in module_fetch_book: + version, param_to_fetch, matching_method = module_fetch_book[type(mod)] + if version < mod._version: + raise RuntimeError( + f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, " + "please upgrade the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly." + ) + for attr in param_to_fetch: + attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version)) + else: + raise RuntimeError( + f"{torch.typename(mod)} is not in the module_fetch_book yet, " + "please add it to the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly." + ) + return attrs_for_lowering + + +@compatibility(is_backward_compatible=False) +def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None: + """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module.""" + submodules = dict(fx_module.named_modules()) + + for node in fx_module.graph.nodes: + if node.op == "call_module": + if isinstance(submodules[node.target], GraphModule): + lift_lowering_attrs_to_nodes(submodules[node.target]) + else: + node.attrs_for_lowering = extract_attrs_for_lowering( + submodules[node.target] + ) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/pass_manager.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..48dfe702fedbb5b4369872d7fcceb482111647b1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/pass_manager.py @@ -0,0 +1,253 @@ +# mypy: allow-untyped-defs +import logging +from functools import wraps +from inspect import unwrap +from typing import Callable, Optional + + +logger = logging.getLogger(__name__) + +__all__ = [ + "PassManager", + "inplace_wrapper", + "log_hook", + "loop_pass", + "this_before_that_pass_constraint", + "these_before_those_pass_constraint", +] + + +# for callables which modify object inplace and return something other than +# the object on which they act +def inplace_wrapper(fn: Callable) -> Callable: + """ + Convenience wrapper for passes which modify an object inplace. This + wrapper makes them return the modified object instead. + + Args: + fn (Callable[Object, Any]) + + Returns: + wrapped_fn (Callable[Object, Object]) + """ + + @wraps(fn) + def wrapped_fn(gm): + fn(gm) + return gm + + return wrapped_fn + + +def log_hook(fn: Callable, level=logging.INFO) -> Callable: + """ + Logs callable output. + + This is useful for logging output of passes. Note inplace_wrapper replaces + the pass output with the modified object. If we want to log the original + output, apply this wrapper before inplace_wrapper. + + + ``` + def my_pass(d: Dict) -> bool: + changed = False + if "foo" in d: + d["foo"] = "bar" + changed = True + return changed + + + pm = PassManager(passes=[inplace_wrapper(log_hook(my_pass))]) + ``` + + Args: + fn (Callable[Type1, Type2]) + level: logging level (e.g. logging.INFO) + + Returns: + wrapped_fn (Callable[Type1, Type2]) + """ + + @wraps(fn) + def wrapped_fn(gm): + val = fn(gm) + logger.log(level, "Ran pass %s\t Return value: %s", fn, val) + return val + + return wrapped_fn + + +def loop_pass( + base_pass: Callable, + n_iter: Optional[int] = None, + predicate: Optional[Callable] = None, +): + """ + Convenience wrapper for passes which need to be applied multiple times. + + Exactly one of `n_iter`or `predicate` must be specified. + + Args: + base_pass (Callable[Object, Object]): pass to be applied in loop + n_iter (int, optional): number of times to loop pass + predicate (Callable[Object, bool], optional): + + """ + assert (n_iter is not None) ^ (predicate is not None), ( + "Exactly one of `n_iter`or `predicate` must be specified." + ) + + @wraps(base_pass) + def new_pass(source): + output = source + if n_iter is not None and n_iter > 0: + for _ in range(n_iter): + output = base_pass(output) + elif predicate is not None: + while predicate(output): + output = base_pass(output) + else: + raise RuntimeError( + f"loop_pass must be given positive int n_iter (given " + f"{n_iter}) xor predicate (given {predicate})" + ) + return output + + return new_pass + + +# Pass Schedule Constraints: +# +# Implemented as 'depends on' operators. A constraint is satisfied iff a list +# has a valid partial ordering according to this comparison operator. +def _validate_pass_schedule_constraint( + constraint: Callable[[Callable, Callable], bool], passes: list[Callable] +): + for i, a in enumerate(passes): + for j, b in enumerate(passes[i + 1 :]): + if constraint(a, b): + continue + raise RuntimeError( + f"pass schedule constraint violated. Expected {a} before {b}" + f" but found {a} at index {i} and {b} at index{j} in pass" + f" list." + ) + + +def this_before_that_pass_constraint(this: Callable, that: Callable): + """ + Defines a partial order ('depends on' function) where `this` must occur + before `that`. + """ + + def depends_on(a: Callable, b: Callable): + return a != that or b != this + + return depends_on + + +def these_before_those_pass_constraint(these: Callable, those: Callable): + """ + Defines a partial order ('depends on' function) where `these` must occur + before `those`. Where the inputs are 'unwrapped' before comparison. + + For example, the following pass list and constraint list would be invalid. + ``` + passes = [ + loop_pass(pass_b, 3), + loop_pass(pass_a, 5), + ] + + constraints = [these_before_those_pass_constraint(pass_a, pass_b)] + ``` + + Args: + these (Callable): pass which should occur first + those (Callable): pass which should occur later + + Returns: + depends_on (Callable[[Object, Object], bool] + """ + + def depends_on(a: Callable, b: Callable): + return unwrap(a) != those or unwrap(b) != these + + return depends_on + + +class PassManager: + """ + Construct a PassManager. + + Collects passes and constraints. This defines the pass schedule, manages + pass constraints and pass execution. + + Args: + passes (Optional[List[Callable]]): list of passes. A pass is a + callable which modifies an object and returns modified object + constraint (Optional[List[Callable]]): list of constraints. A + constraint is a callable which takes two passes (A, B) and returns + True if A depends on B and False otherwise. See implementation of + `this_before_that_pass_constraint` for example. + """ + + passes: list[Callable] + constraints: list[Callable] + _validated: bool = False + + def __init__( + self, + passes=None, + constraints=None, + ): + self.passes = passes or [] + self.constraints = constraints or [] + + @classmethod + def build_from_passlist(cls, passes): + pm = PassManager(passes) + # TODO(alexbeloi): add constraint management/validation + return pm + + def add_pass(self, _pass: Callable): + self.passes.append(_pass) + self._validated = False + + def add_constraint(self, constraint): + self.constraints.append(constraint) + self._validated = False + + def remove_pass(self, _passes: list[str]): + if _passes is None: + return + passes_left = [ps for ps in self.passes if ps.__name__ not in _passes] + self.passes = passes_left + self._validated = False + + def replace_pass(self, _target, _replacement): + passes_left = [] + for ps in self.passes: + if ps.__name__ == _target.__name__: + passes_left.append(_replacement) + else: + passes_left.append(ps) + self.passes = passes_left + self._validated = False + + def validate(self): + """ + Validates that current pass schedule defined by `self.passes` is valid + according to all constraints in `self.constraints` + """ + if self._validated: + return + for constraint in self.constraints: + _validate_pass_schedule_constraint(constraint, self.passes) + self._validated = True + + def __call__(self, source): + self.validate() + out = source + for _pass in self.passes: + out = _pass(out) + return out diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/reinplace.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/reinplace.py new file mode 100644 index 0000000000000000000000000000000000000000..3edf2df612e7fef391c6d2eae618a6447a9446cd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/reinplace.py @@ -0,0 +1,754 @@ +# mypy: allow-untyped-defs +import _operator +import itertools +from collections import defaultdict +from enum import Enum +from typing import Any, Callable + +import torch +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx import Node +from torch.fx._compatibility import compatibility +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_map_only + + +__all__ = ["reinplace"] + + +class _ViewType(Enum): + NonView = 0 + SingleOutputView = 1 + MultiOutputView = 2 + + +def _is_view_op(tgt): + if tgt is not None and isinstance(tgt, torch._ops.OpOverload): + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + return ( + first_arg.alias_info is not None and not first_arg.alias_info.is_write + ) + + +def _get_view_type(tgt) -> _ViewType: + if tgt is not None and isinstance(tgt, torch._ops.OpOverload): + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + if first_arg.alias_info is not None and not first_arg.alias_info.is_write: + # check if op is a multi-output view + if "*" in first_arg.alias_info.after_set: + return _ViewType.MultiOutputView + else: + return _ViewType.SingleOutputView + return _ViewType.NonView + + +# Stores a bunch of metadata related to functionalization each node. +# Relevant metadata: +# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTenors instead of Tensors) +# The fake tensor output from running the current node +# n.meta['view_of']: Node +# If the current node n is a view of some base tensor, the 'view_of' field tells us which +# view node was used to generate the current node (a view tensor). +# This information actually makes `fake_result` redundant, but we can use `fake_result` +# to sanity check that our aliasing information is correct. +@compatibility(is_backward_compatible=False) +class _FunctionalizationMetadataProp(torch.fx.Interpreter): + def run_node(self, node: Node): + self.node_counter += 1 + result = super().run_node(node) + node.meta["fake_result"] = result + node.meta["node_idx"] = self.node_counter + + # (1) Update metadata with the list of nodes that are used by this node + # copy_() doesn't read from its first argument; it writes to it, overwriting previous data. + # We don't want to treat it as "being used as an input". + node_args = node.args + if node.target is torch.ops.aten.copy_.default: + node_args = node_args[1:] + + # (2) Update metadata to track aliasing information about view tensor nodes. + if node.op == "call_function": + view_type = _get_view_type(node.target) + if view_type == _ViewType.SingleOutputView: + assert isinstance(node.args[0], Node) + node.meta["view_of"] = node.args[0] + elif view_type == _ViewType.MultiOutputView: + self.multi_output_view_nodes[node] = node.args[0] + + # Check if we returned a multi-output view, + # and we're now grabbing the individual views from the output. + # + # For multi-output views, we want to map each output view to the base, + # but this mapping involves two separate nodes in FX IR. + # e.g. "a, b = x_1.split(...)" becomes: + # %split_tensor : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {}) + # %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {}) + # %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {}) + # And we'd like to set: + # getitem1.meta['view_of'] = x_1 + elif node.target is _operator.getitem: + list_arg = node.args[0] + maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None) + if maybe_base_of_view is not None: + # Note: we could also track indexing info here for multi-output views. + # I don't think this metadata is strictly needed for de-functionalization. + assert isinstance(maybe_base_of_view, Node) + node.meta["view_of"] = maybe_base_of_view + + if "view_of" in node.meta: + # We're linking the current node with its first argument as views. + # Assert here that this is actually the case, and their storages are the same. + assert isinstance(node.meta["fake_result"], FakeTensor) + assert isinstance(node.meta["view_of"].meta["fake_result"], FakeTensor) + view_storage = StorageWeakRef(node.meta["fake_result"]._typed_storage()) + base_storage = StorageWeakRef( + node.meta["view_of"].meta["fake_result"]._typed_storage() + ) + assert view_storage == base_storage + return result + + def propagate(self, *args): + self.multi_output_view_nodes = {} + self.node_counter = -1 + + with FakeTensorMode() as mode: + fake_args = [ + mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args + ] + return super().run(*fake_args) + + +def _schemas_match(functional_schema, inplace_schema): + names_match = ( + inplace_schema.name.endswith("_") + and inplace_schema.name[:-1] == functional_schema.name + ) + arg_types_match = len(functional_schema.arguments) == len( + inplace_schema.arguments + ) and all( + a1.type == a2.type + for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments) + ) + # for the inplace op, its first argument should be mutable + assert ( + inplace_schema.arguments[0].alias_info is not None + and inplace_schema.arguments[0].alias_info.is_write + ) + # and its remaining arguments shouldn't be. + assert all(a.alias_info is None for a in inplace_schema.arguments[1:]) + return names_match and arg_types_match + + +# TODO: this should be beefed up to be able to properly re-inplace with: +# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper) +# - out= ops (e.g. angle -> angle.out) +# TODO: we should also figure this info out using torchgen. +def _maybe_get_inplace_op(op): + # __module__ seems broken; it returns torch._ops.aten which doesn't exist + if not isinstance(op, torch._ops.OpOverload): + return None + # Some view ops have inplace variants (as_strided_, etc), + # but we do NOT want the reinplacing pass to directly add these into the program. + # (they'll require extra special handling, aren't aren't really useful for perf anyway) + if _is_view_op(op): + return None + op_namespace = op.__module__.split(".")[-1] + op_base_name = op.overloadpacket.__name__ + maybe_namespace_module = getattr(torch.ops, op_namespace) + maybe_inplace_op = ( + None + if maybe_namespace_module is None + else getattr(maybe_namespace_module, f"{op_base_name}_", None) + ) + if maybe_inplace_op is None: + return None + + inplace_overloads = [ + getattr(maybe_inplace_op, overload_name) + for overload_name in maybe_inplace_op.overloads() + ] + inplace_overloads_with_matching_schemas = [ + f for f in inplace_overloads if _schemas_match(op._schema, f._schema) + ] + # Just because foo() and foo_() are both existing operators, + # They aren't guaranteed to have compatible schemas. + # For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant, + # Even though several overloads of pow_ exist. + if len(inplace_overloads_with_matching_schemas) == 0: + return None + assert len(inplace_overloads_with_matching_schemas) == 1 + inplace_op = inplace_overloads_with_matching_schemas[0] + return inplace_op + + +_VIEW_INVERSE_MAP: dict[Callable[..., Any], Callable[..., Any]] = { + torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, + torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, + torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, + torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, +} + + +# This function, given a set of set of (aliased) tensor nodes, +# Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index +# in the node ordering. +def _get_all_later_node_usages(tensor_aliases: set[Node], op_index: int): + def _add_if_tensor(x, set_): + if isinstance(x, FakeTensor): + set_.add(StorageWeakRef(x._typed_storage())) + + nodes_used_after = set() + for t in tensor_aliases: + # get all nodes that use the current alias + usage_nodes = t.users + for n in usage_nodes: + # We only care about usages after the current node + if "node_idx" not in n.meta or n.meta["node_idx"] <= op_index: + continue + # We also don't care about intermediate view ops. + # They only matter if their output is then used elsewhere + # (either in an out-of-place op, or as an output to the function). + if n in tensor_aliases: + if ( + isinstance(n.target, torch._ops.OpOverload) + or n.target == _operator.getitem + ): + continue + nodes_used_after.add(n) + return nodes_used_after + + +# Given an op that we're trying to re-inplace, "b = foo(a)", +# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)" +# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF: +# If there are any aliases in the alias_set(a) that satisfy: +# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base" +# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata +# as "alias" +def _get_view_inverse_node_usages( + later_node_usages: set[Node], self_aliases: set[Node] +) -> set[Node]: + def matching_view_metadata(a, b): + return ( + a.size() == b.size() + and a.stride() == b.stride() + and a.storage_offset() == b.storage_offset() + ) + + view_inverse_nodes = set() + # Go through them in node order, so we can see chains of view_scatter ops. + for n in sorted(later_node_usages, key=lambda x: x.meta["node_idx"]): + if n.target not in _VIEW_INVERSE_MAP: + continue + base = n.args[0] + mutated_view = n.args[1] + assert isinstance(base, Node) + assert isinstance(base.meta["fake_result"], FakeTensor) + assert isinstance(mutated_view, Node) + assert isinstance(mutated_view.meta["fake_result"], FakeTensor) + assert not isinstance(n.target, str) + # Check that this view_inverse op actually corresponds to taking doing the inverse + # of one of our existing self_alias nodes. + original_view = _VIEW_INVERSE_MAP[n.target] + for self_alias in self_aliases: + # We're looking for some alias of the self arg, "alias", + # that was created from some op `alias = foo(base, args...)` + # such that the current _scatter op "inverts" that foo call. + # We can check that by running the original op again, and checking that the strides match. + if "view_of" not in self_alias.meta: + continue + self_alias_base = self_alias.meta["view_of"] + try: + # The we're trying to re-use the args from the view_scatter call inside of the corresponding + # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse + # of the current alias we're looking at. + view_replay_metadata = original_view( + self_alias_base.meta["fake_result"], *n.args[2:], **n.kwargs + ) + expected_metadata = self_alias.meta["fake_result"] + # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace. + if matching_view_metadata( + self_alias_base.meta["fake_result"], base.meta["fake_result"] + ) and matching_view_metadata(view_replay_metadata, expected_metadata): + view_inverse_nodes.add(n) + except Exception: + continue + + return view_inverse_nodes + + +@compatibility(is_backward_compatible=True) +def reinplace(gm, *sample_args): + """ + Given an fx.GraphModule, modifies it to perform "reinplacing", + mutating the nodes of the graph. + We look for out-of-place op call sites like `b = a.add(...)`, + and convert them to be inplace (`b = a.add_(...)`), + as long as the input to the current operator ("a") isn't re-used + anywhere later in the graph. + + This pass currently expects to operate on a **functional, ATen** graph. + This can be obtained by running `make_fx(functionalize(f))`. + + Sample inputs are needed to determine aliasing relationships of the inputs. + In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the + inputs to the program. + + Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows: + + (1) Perform some initial checks on the metadata of "a" and "args..." + that can disqualify them from being reinplaced. + + (1a) Check that the self argument we're attempting to reinplace + has acceptable dtype/size metadata to reinplace with. + + For example, if we have: + a = torch.ones(1) + b = torch.ones(10) + out = torch.add(a, b) + We can't turn that into + a.add_(b) + Because that would require resizing "a". + + Similarly, we can't convert torch.ge(a, b) into a.ge_(b), + because that would require changing a's dtype (from e.g. float32 to bool). + Note that in this specific example, we could technically do better.. + + If we see the pattern: + a_1 = a.ge(b) + a_2 = aten._to_copy(a_1, a.dtype) + Then we this should be valid to completely re-inplace + (this is exactly what functionalization will emit when it sees a.ge_(b)). + + This optimization is only really important for user programs + that directly use inplace comparison ops though. + + We also cannot re-inplace on tensors that have overlapping memory, + e.g. torch.ones(1).expand(4, 4).add_(1) + + (1b) Check if "a" is an alias of any of the program inputs. + + If it is, skip and move to the next node. + Inplace'ing an op that would cause it to mutate a program is not sound, + because that would be a side effect visible to the user. + + NOTE: there's a future optimization that we should make: + if "a" is a (alias of a) program input, but later in the program + there is a node that looks like "a.copy_(...)", + Then re-inplacing is ok to do - we are temporarily re-using a's buffer, + which will later be overwritten by the copy_() call. + + This will be an important optimization to have for programs that mutate + their inputs. It currently isn't implemented though. + + (1c) Check if "a" and "args..." alias + + For example, re-inplacing to create code like the below + isn't guaranteed to be sound: + + aten.mul_(a, a) + + (2) Check that "a" and all of its outstanding aliases are not used anywhere + later in the graph. If this is the case, then it's safe to re-inplace + to "b = foo_(a)". + + There are a few caveats to this, explained in more detail below: + (a) If "a" is used later as an argument to a view op, that is okay. + It's only a problem if "a" (or that view) is later passed + into a normal operator, or if it is returned as the program output. + (b) If "a" is a repeat argument in `foo()`, then don't reinplace. + Most ATen kernels don't make any guarantees that this is sound, + e.g. if you do aten.mul_(a, a). + So we'll just ban re-inplacing in this case. + It's only a problem if "a" (or that view) is later passed + (c) If "a" is used as an input into a view "inverse" / "scatter" + operator, it is potentially fine to re-inplace + (and remove that scatter operator from the graph). + See below for a more detailed example. + + NOTE: there is an optimization in this step that is crucial + to fully recovering performance from functionalization. + + Given this program: + def f(x): + a = torch.ops.aten.add(x, x) + b = torch.ops.aten.diagonal(a) + torch.ops.aten.fill_(b, 0) + return d + + Functionalization will emit the following: + def f(x): + a = torch.ops.aten.add(x, x) + b = torch.ops.aten.diagonal(a, 0, 1) + b_updated = torch.ops.aten.fill(b, 0) + a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1) + return a_updated + + Ordinarily, we would not be able to reinplace the fill, + because "b" aliases with "a" which is used by the diagonal_scatter call. + + "re-inplacing" is on the hook for figuring out that it is ok to + completely, the expensive diagonal_scatter call, if we re-inplace the add(). + + So, for every `alias in alias_set(a)`, instead of checking + that "alias" is not used anywhere later in the graph, + we check that + EITHER: + (a) alias is not used anywhere later in the graph + OR: + (b) alias is used exactly once later on in the graph, + in the following op: + + out = foo_scatter(alias, x, args...) + + where the following must hold: + (i) "foo_scatter" is the "inverse" operator for foo. + This only applies to "foo" ops that are view operators, + which view into a subset of the original tensor's memory. + In practice, there are ~4 operators where this applies: + diagonal -> diagonal_scatter + slice -> slice_scatter + select -> select_scatter + as_strided -> as_strided_scatter + (ii) "args..." are the same between the foo() and foo_scatter() calls. + + (3) Perform the actual re-inplacing on foo! + + (3b) is the common case, but special care is needed for {view}_scatter (3a) + + (3a) {view}_scatter ops. + + Consider this program: + a = torch.zeros(2, 2) + b = torch.ones(2) + a[0] = b + + Post functionalization, that will look like: + a = torch.zeros(2) + b = torch.ones(1) + a_updated = torch.select_scatter(a, b, 0, 0) + + In this case though, there is no "functional" op to re-inplace! + Instead, we'd like to directly remove toe select_scatter call. + We already know from (3) that this is valid, + because "a" has no later usages in the graph. + + We perform the re-inplacing on the {view}_scatter op like so + Before: + a_updated = torch.select_scatter(a, b, args...) + After: + a_slice = a.select(a, args...) + a_slice.copy_(b) + + (3b) Otherwise, replace the functional op with its inplace variant. + Before: + b = foo(a, args...) + After: + a.foo_(args...) + + (4) Finally, after converting either: + Before: + b = foo(a) + After: + foo_(a) + or + Before: + b = {slice}_scatter(a, mutated_slice, args...) + After: + slice = {slice}(a, args...) + slice.copy_(mutated_slice) + + We now need to find all later nodes that use "b" as an argument + and update them to take in "a" instead. + + Note that for the majority of inplace ops, this isn't actually necessary + (because most inplace ops return "self" as their output). + This isn't generally true for all mutable ops though, which is why + we need to actually replace all of the arguments. + + We also need to update our metadata of Dict[StorageWeakRef, Set[Node]], + That maps a given tensor storage to the set of all nodes that take in that storage + as an input. + Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused + together. + + (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them" + during step (3) get manually deleted from the graph. + Their outputs are no longer used, so technically standard DCE would be able + to do this, but we can no longer run FX's DCE pass now that we have mutable + ops in the graph. + """ + _FunctionalizationMetadataProp(gm).propagate(*sample_args) + + # Useful debug printing + # def _print(x): + # if isinstance(x, FakeTensor): + # print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}') + + # for n in gm.graph.nodes: + # print(n.format_node()) + # if hasattr(n, 'meta'): + # print(f'node_idx: {n.meta["node_idx"]}') + # if 'fake_result' in n.meta: + # tree_map(_print, n.meta['fake_result']) + # if 'view_of' in n.meta: + # print(f'view_of: {str(n.meta["view_of"])}') + # print() + + # We need to know which nodes correspond to inputs (or their aliases) + # so we know not to re-inplace them. + # NOTE: later, we'll need to add an optimization for fully recovering performance + # on programs that mutate inputs. + input_storages = { + StorageWeakRef(node.meta["fake_result"]._typed_storage()) + for node in gm.graph.nodes + if ( + node.op == "placeholder" + and isinstance(node.meta["fake_result"], torch.Tensor) + ) + } + + # We also need to know for a given node, what are all of its aliasing nodes. + storage_to_nodes: dict[StorageWeakRef, set[Node]] = defaultdict(set) + for n in gm.graph.nodes: + if "fake_result" in n.meta: + # Tree-mapping because some ops can return lists of tensors. + def _add_to_map(x): + if isinstance(x, FakeTensor): + storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n) + + pytree.tree_map_(_add_to_map, n.meta["fake_result"]) + + # inplace-ify functional ops, subject to the constraints written below. + all_later_view_inverse_nodes_to_delete = set() + for node in gm.graph.nodes: + if node.op == "call_function": + # Today, the re-inplace pass on directly acts on: + # - functional ops with an inplace variant + # - {view}_scatter ops that can be potentially removed from the graph. + # Both of these ops take in tensor first args, so filtering on this condition + # makes the later code simpler. + # We should revisit this at some point though, particularly when we also want + # the reinplacer to be able to handle out= and mutable operators + # and tensorlist first args (like `_foreach_` ops). + if not isinstance(node.target, torch._ops.OpOverload): + continue + if len(node.target._schema.arguments) < 1: + continue + if type(node.target._schema.arguments[0].type) != torch.TensorType: + continue + + # Step 1a: Check that the self argument we're attempting to reinplace + # has the same size/stride as the output. + # For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor) + # As it would require resizing scalar_tensor. + # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor), + # this is probably an optimization to revisit later). + self_arg = node.args[0] + self_flattened = pytree.tree_leaves(self_arg.meta["fake_result"]) + node_flattened = pytree.tree_leaves(node.meta["fake_result"]) + self_has_wrong_metadata = False + if len(self_flattened) == len(node_flattened): + for self_meta, node_meta in zip(self_flattened, node_flattened): + if self_meta.numel() != node_meta.numel(): + self_has_wrong_metadata = True + if self_meta.dtype != node_meta.dtype: + self_has_wrong_metadata = True + # We also cannot re-inplace on tensors that have internal memory overlap. + # e.g. torch.ones(1).expand(4, 4).add_(1) + if torch._debug_has_internal_overlap(self_meta) == 1: + self_has_wrong_metadata = True + # Here, we (optimistically) assume that a.resize(b) is valid to re-inplace, + # Since users should never really be calling the functional "torch.ops.aten.resize" + # op directly in their programs. + if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default: + continue + + # Step 1b: ensure that the op we're trying to re-inplace isn't a program input + self_arg_storage = StorageWeakRef( + self_arg.meta["fake_result"]._typed_storage() + ) + if self_arg_storage in input_storages: + # TODO: later, add the optimization for handling `copy_()` calls in the graph. + continue + if len([x for x in node.args if x is self_arg]) > 1: + # Step 1c: + # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound, + # so we prevent re-inplacing in this case. + continue + + self_arg_storage = StorageWeakRef( + self_arg.meta["fake_result"]._typed_storage() + ) + self_aliases = storage_to_nodes[self_arg_storage] + + # First, we find all later usages of any of the aliases of self_arg. + later_node_usages = _get_all_later_node_usages( + self_aliases, node.meta["node_idx"] + ) + # Then, we check if any of those later usages are actually view_scatter ops + # that are safe to fully remove. + later_view_inverse_node_usages = _get_view_inverse_node_usages( + later_node_usages, self_aliases + ) + + # Step 2: Check to see if the input to the op is re-used later in the graph. + # If not (same goes for its aliases), then this op is safe to re-in place. + # This is a slightly roundabout way to check that there are no later usages of the current self argument. + # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete) + can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0 + if not can_reinplace: + continue + + # Step 3a: Special handling for when we see *_scatter operators. + # When we see an operator like `b = torch.slice_scatter(a, ...)`, + # instead of trying to "inplace" it into a.slice_scatter_(..._), + # we would prefer to remove it from the graph entirely, + # and instead copy_() the slice directly into the larger tensor. + # See the description of the algorithm for a full example. + if ( + node.target in _VIEW_INVERSE_MAP + and node not in all_later_view_inverse_nodes_to_delete + ): + view_op = _VIEW_INVERSE_MAP[node.target] + # Before: + # base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...) + # After: + # slice = torch.ops.aten.slice.default(base, args...) + # slice.copy_(mutated_slice) + with gm.graph.inserting_before(node): + mutated_slice_node = node.args[1] + remaining_slice_args = node.args[2:] + slice_node = gm.graph.create_node( + "call_function", + view_op, + (self_arg,) + tuple(remaining_slice_args), + node.kwargs, + ) + gm.graph.create_node( + "call_function", + torch.ops.aten.copy_.default, + ( + slice_node, + mutated_slice_node, + ), + {}, + ) + # Add the slice_scatter node to our "nodes to delete" list. + all_later_view_inverse_nodes_to_delete.add(node) + + else: + # Step 3b: Check to see if this operator has an inplace variant. + maybe_inplace_op = _maybe_get_inplace_op(node.target) + if maybe_inplace_op is None: + continue + # And if so, replace it with its inplace variant. + node.target = maybe_inplace_op + + # At this point, 'storage_to_nodes' will be stale. + # Now that we're inplacing `b = foo(a)`, we need to effectively + # union together the dict values for b and a's storage. + # Hmm... morally I think we also want to keep the `fake_result` metadata + # up to date here, but I'm not sure how easy it is to do. + # Maybe it's fine to wait until the end of the pass to update it. + curr_node_storage = StorageWeakRef( + node.meta["fake_result"]._typed_storage() + ) + storage_to_nodes[self_arg_storage].update( + storage_to_nodes[curr_node_storage] + ) + storage_to_nodes[curr_node_storage].update( + storage_to_nodes[self_arg_storage] + ) + + # Need to remember the view_scatter view nodes we found so we can remove them alter. + all_later_view_inverse_nodes_to_delete.update( + later_view_inverse_node_usages + ) + + # Step 4: + # Now that we've replaced b = a.foo() with a.foo_(), + # We need to replace any later usages of "b" with "a" + for old in itertools.chain([node], later_view_inverse_node_usages): + new = old.args[0] + nodes_to_update = [ + n for n in old.users if n.meta["node_idx"] > node.meta["node_idx"] + ] + for node_to_update in nodes_to_update: + + def replace_arg(a): + if a == old: + return new + return a + + # First, replace usages of "b" with "a" + node_to_update.args = tree_map_only( + Node, replace_arg, node_to_update.args + ) + node_to_update.kwargs = tree_map_only( + Node, replace_arg, node_to_update.kwargs + ) + + # Second, update our storage_to_nodes data structure. + old_flattened_res = pytree.tree_leaves(old.meta["fake_result"]) + node_flattened_res = pytree.tree_leaves( + node_to_update.meta["fake_result"] + ) + + old_res_storage = { + StorageWeakRef(x._typed_storage()) + for x in old_flattened_res + if isinstance(x, FakeTensor) + } + node_res_storage = { + StorageWeakRef(x._typed_storage()) + for x in node_flattened_res + if isinstance(x, FakeTensor) + } + + # This will happen if we're updating a view op, e.g. + # e.g. replacing + # x = view(old) + # x = view(new) + # When that happens, we need to make sure to keep our + # storage mapping up to date. + # + # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor, + # or multiple tensors that all share the same storage. + # We can't just check equality because we might encounter FX nodes that return zero tensor outputs. + if ( + len(old_res_storage) == 1 + and len(node_res_storage) == 1 + and old_res_storage == node_res_storage + ): + new_flattened_res = pytree.tree_leaves(new.meta["fake_result"]) + new_res_storage = { + StorageWeakRef(x._typed_storage()) + for x in new_flattened_res + if isinstance(x, FakeTensor) + } + assert len(new_res_storage) == 1 + (new_ref,) = new_res_storage + (node_ref,) = node_res_storage + # Technically, "old_ref" and all its aliases will remain + # in our mapping. + # That should be fine though, since we deleted "old" + # from the graph at this point. + storage_to_nodes[node_ref].update(storage_to_nodes[new_ref]) + storage_to_nodes[new_ref].update(storage_to_nodes[node_ref]) + + # Step 4: delete any _scatter nodes that we de-functionalized + # Need to take care not to delete any of these nodes until after *all* modifications + # to the graph are finished. + for to_delete in all_later_view_inverse_nodes_to_delete: + gm.graph.erase_node(to_delete) + + gm.recompile() + return gm diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/runtime_assert.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/runtime_assert.py new file mode 100644 index 0000000000000000000000000000000000000000..38c64c527aff066f13a2d32ec69c716995f2f40c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/runtime_assert.py @@ -0,0 +1,633 @@ +# mypy: allow-untyped-defs +import functools +import logging +import operator +import sys +from typing import Any, Optional, TYPE_CHECKING + + +# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow +if TYPE_CHECKING: + import sympy + + from torch.fx.experimental.symbolic_shapes import ShapeEnv +else: + ShapeEnv = Any + +import torch +import torch.utils._pytree as pytree +from torch import fx +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx._compatibility import compatibility +from torch.fx._utils import lazy_format_graph_code +from torch.fx.experimental.proxy_tensor import py_sym_types +from torch.fx.experimental.sym_node import SymNode +from torch.fx.graph_module import GraphModule + + +__all__ = ["insert_deferred_runtime_asserts"] + +log = logging.getLogger(__name__) +graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code_verbose") + + +def _get_example_value(node: fx.Node) -> Optional[str]: + """ + Get the example value key for a node, since dynamo uses "example_value" + while non-strict export uses "val. + """ + if "example_value" in node.meta: + return node.meta["example_value"] + elif "val" in node.meta: + return node.meta["val"] + else: + return None + + +def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]: + val = _get_example_value(node) + if isinstance(val, py_sym_types): + return val.node.expr + return None + + +@compatibility(is_backward_compatible=True) +def insert_deferred_runtime_asserts( + gm: GraphModule, + shape_env: ShapeEnv, + name: str, + export: bool = False, +) -> None: + """ + During tracing, we may have discovered that some data-dependent values + had runtime assert on them; e.g., torch.empty(x.item()) induces a runtime + that x.item() >= 0. This asserts can happen unpredictably during fake + tensor propagation, so we cannot conveniently insert them into the FX graph + when they occur. Instead, we accumulate them in the ShapeEnv, and in this + pass insert them into the graph as proper tests. + + This pass also deduplicates size-related computation, CSE-ing ops that produce + symbolic values and/or are involved in runtime asserts. Additionally, shape calls + (size/stride/storage_offset) are turned into compute on input sizes if possible, + allowing intermediate tensors to be freed earlier. For example, here dynamo will + DCE the cat and repeat calls: + + z = torch.cat([x, x], dim=0) # 2*s0 + w = z.repeat(y.shape[0]) # 2*s0*s1 + _w = w.shape[0] + # something with _w, but not w ... + + # turns into -> + _w0 = 2 * s0 + _w = _w0 * s1 + + # where s0, s1 are either SymInt graph inputs, or the result of added size calls + + Redundant torch._check or torch.ops.aten._assert_scalar.default calls that assert + the same expression, and redundant constrain_range calls are also deduplicated. + Additionally, because single-symbol bound checks (e.g. u0 >= 0, u0 <= 5) accumulate + information in the ShapeEnv, the ShapeEnv contains min/max bounds for each symbol, + and we delete all previous calls, adding bound checks at the end of this pass. + """ + + # Import sympy locally + import sympy + + from torch._export.passes._node_metadata_hook import _set_node_metadata_hook + from torch.fx.experimental.symbolic_shapes import ( + _get_placeholder_expr, + _has_uninterpretable_sympy_function, + CallMethodKey, + cast_symbool_to_symint_guardless, + ConvertIntKey, + DivideByKey, + free_symbols, + InnerTensorKey, + resolve_unbacked_bindings, + ) + from torch.utils._sympy.numbers import int_oo + from torch.utils._sympy.reference import ( + OptimizedPythonReferenceAnalysis, + PythonReferenceAnalysis, + ) + from torch.utils._sympy.value_ranges import ValueRanges + + # TODO: Request simplification on runtime asserts before emitting them + ras_by_symbol = shape_env.deferred_runtime_asserts.copy() + graph = gm.graph + tracer = fx.proxy.GraphAppendingTracer(graph) + graph_code_log.debug( + "%s", + lazy_format_graph_code( + f"pre insert_deferred_runtime_asserts {name}", gm, colored=True + ), + ) + + # We are going to mutate the dict + expr_to_proxy: dict[sympy.Expr, fx.Proxy] = {} + placeholders = set() + first_non_placeholder = None + for node in graph.nodes: + if node.op != "placeholder": + first_non_placeholder = node + break + else: + placeholders.add(node) + + def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool: + """ + If a size/stride/storage offset call on an intermediate tensor, + we can try to compute the value from input shapes instead. + """ + return ( + (val := _get_sym_val(node)) is not None + and not isinstance(val, sympy.Number) + # this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported + and not _has_uninterpretable_sympy_function(val) + and any( + isinstance(arg, fx.Node) + and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size)) + and arg.op != "placeholder" + for arg in node.args + ) + ) + + # Figure out what key to use, val or example_value + val_key = "val" + for node in graph.nodes: + if "example_value" in node.meta: + val_key = "example_value" + break + elif "val" in node.meta: + break + + def _node_metadata_hook( + node: torch.fx.Node, + stack_trace: Optional[str] = None, + nn_module_stack: Optional[dict[str, Any]] = None, + ) -> None: + fake_args = pytree.tree_map( + lambda arg: ( + _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg + ), + node.args, + ) + try: + target = node.target + if node.op == "call_method": + assert isinstance(node.target, str) + target = getattr(fake_args[0], node.target) + fake_args = fake_args[1:] + node.meta[val_key] = target(*fake_args) # type: ignore[operator] + except NotImplementedError: + # This can happen when attempting to reify a symbol with an unsupported call_function node, + # e.g. with NestedTensors + sym_size.int via match_symbol(). + # This seems to be fine, as the node gets CSE'd and deleted later in favor of a SymInt graph input. + pass + if stack_trace is not None: + node.meta["stack_trace"] = stack_trace + if nn_module_stack is not None: + node.meta["nn_module_stack"] = nn_module_stack + + # Track asserts/checks we've added + added_asserts: set[sympy.Expr] = set() + constrained_unbacked_symbols: set[sympy.Symbol] = set() + + Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis + + def _sympy_interp(expr_to_proxy, expr): + # sympy_interp() with hash consing + from sympy import Integer, Number, Symbol + from sympy.logic.boolalg import BooleanAtom + + from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp + + # hash cons + if expr in expr_to_proxy: + return expr_to_proxy[expr] + # base cases, don't cache + if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)): + return sympy_interp(Analysis, expr_to_proxy, expr) + + # hash cons on arguments, run expr handler + expr_to_proxy[expr] = _run_sympy_handler( + Analysis, + [_sympy_interp(expr_to_proxy, arg) for arg in expr.args], + expr, + ) + return expr_to_proxy[expr] + + def _is_bound_expr_for_symbol(expr: "sympy.Expr") -> bool: + # This is probably unnecessary, but since torch._check() calls for single-symbol bounds + # like u0 >= 0, 10 >= u0 accumulate range info in the ShapeEnv, we designate these calls as redundant + # and instead add 2 runtime asserts at the end of this pass, if the min/max bounds are non-trivial. + if len(expr.args) != 2 or expr.func not in (sympy.LessThan, sympy.GreaterThan): + return False + lhs, rhs = expr.args + return (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Number)) or ( + isinstance(rhs, sympy.Symbol) and isinstance(lhs, sympy.Number) + ) + + def add_runtime_asserts(ras): + for ra in ras: + if ( + # redundant + ra.expr in added_asserts + # if we've already added a constrain_range call for this symbol, + # then single-symbol bound asserts like u0 >= 0, u0 <= 5 are redundant. + or ( + len(ra.expr.free_symbols) == 1 + and next(iter(ra.expr.free_symbols)) in constrained_unbacked_symbols + and _is_bound_expr_for_symbol(ra.expr) + ) + # don't try to reify sympy functions we can't turn into FX nodes + or _has_uninterpretable_sympy_function(ra.expr) + ): + continue + + log.debug("inserting runtime assert %s", ra.expr) + # Need to process ALL free symbols, not just unbacked ones + fvs = free_symbols(ra.expr) + missing = fvs - expr_to_proxy.keys() + if missing: + i1 = min(missing, key=str) + # TODO: Remove relaxing assert on unbacked_symint https://github.com/pytorch/pytorch/issues/119689 + # assert shape_env.is_unbacked_symint(i1), i1 + ras_by_symbol.setdefault(i1, []).append(ra) + else: + # Convert the sympy expression into a sequence of FX + # nodes + with _set_node_metadata_hook(gm, _node_metadata_hook): + res = _sympy_interp(expr_to_proxy, ra.expr).node + + graph.call_function( + torch.ops.aten._assert_scalar.default, + # TODO: use ra.msg here, but it's pretty + # useless right now + ( + res, + f"Runtime assertion failed for expression {ra.expr} on node '{res}'", + ), + ) + added_asserts.add(ra.expr) + + nodes = list(graph.nodes) + for i, node in enumerate(nodes[:-1]): + # Placeholders can match symbols, but when we destructure them + # with size we have to make sure we insert the nodes after all + # the placeholders + with graph.inserting_before( + nodes[i + 1] if node not in placeholders else first_non_placeholder + ): + # Unfortunately, this logic still must remain because manual + # make_fx calls may not explicitly bind all symbolic ints as + # arguments to the function, so we must infer it from the other + # arguments + if ( + node in placeholders + and (example_value := _get_example_value(node)) is not None + ): + + def match_symbol(symint, cb): + if ( + isinstance(symint, torch.SymInt) + and isinstance(symint.node, SymNode) + and isinstance( + s := _get_placeholder_expr(symint.node), sympy.Symbol + ) + and s not in expr_to_proxy + ): + with _set_node_metadata_hook(gm, _node_metadata_hook): + expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer) + log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) + + match_symbol(example_value, lambda: node) + if isinstance(t := example_value, torch.Tensor): + for i, s in enumerate(t.size()): + match_symbol( + s, + lambda: graph.call_function( + torch.ops.aten.sym_size.int, (node, i) + ), + ) + if not is_sparse_any(t): + for i, s in enumerate(t.stride()): + match_symbol( + s, + lambda: graph.call_function( + torch.ops.aten.sym_stride.int, (node, i) + ), + ) + match_symbol( + t.storage_offset(), + lambda: graph.call_function( + torch.ops.aten.sym_storage_offset.default, (node,) + ), + ) + + # Handle asserts that aren't associated with any symbol. This + # doesn't really have to be in the loop as it will only run once, + # it just needs to happen right after the placeholders. + # insert this after placeholders & added sym nodes, and before non-placeholders. + if node == first_non_placeholder: + add_runtime_asserts(ras_by_symbol.pop(None, [])) # type: ignore[call-overload] + + # deduplicate asserts already present in graph, and remove trivial asserts + if node.target in ( + torch._check, + torch.ops.aten._assert_scalar.default, + ): + if ( + node.args[0] == True # noqa: E712 + or (assert_expr := _get_sym_val(node.args[0])) in expr_to_proxy + and assert_expr in added_asserts + ): + arg = node.args[0] + gm.graph.erase_node(node) + if isinstance(arg, fx.Node) and not arg.users: + gm.graph.erase_node(arg) + else: + added_asserts.add(assert_expr) # type: ignore[arg-type] + + # hash cons, replace function calls that return torch.SymInts with direct references to + # FX nodes built up to reify the sympy expression. + if ( + node.op != "placeholder" + and (sym_expr := _get_sym_val(node)) is not None + ): + # this guards against deleting calls like item() that produce new untracked symbols + def has_new_untracked_symbols(): + for symbol in sym_expr.free_symbols: + if symbol not in expr_to_proxy: + return True + return False + + # this guards against deleting calls that produce unbacked bindings we haven't yet seen. + # in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint + # (is backed), but produces an unbacked symbol. In this case keep the node alive. + resolved_unbacked_bindings = resolve_unbacked_bindings( + shape_env, node.meta.get("unbacked_bindings", {}) + ) + + assert resolved_unbacked_bindings is not None + + def has_new_unbacked_bindings(): + for key in resolved_unbacked_bindings.keys(): + if key not in expr_to_proxy: + return True + return False + + # maybe re-reify expression, replace current node + if ( + sym_expr in expr_to_proxy + or ( # example value is redundant + _is_intermediate_tensor_sym_call(node) + # shape call on intermediate tensor, turn into computation on input shapes + and not has_new_untracked_symbols() + ) + ) and not has_new_unbacked_bindings(): + if _is_intermediate_tensor_sym_call( + node + ): # reify from input shapes + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + nn_module_stack=node.meta.get("nn_module_stack"), + ), + ): + expr_to_proxy[sym_expr] = _sympy_interp( + expr_to_proxy, sym_expr + ) # type: ignore[arg-type] + # won't try DCE-ing tensor compute here + hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type] + node.replace_all_uses_with(hash_node) + gm.graph.erase_node(node) + log.debug( + "CSE node %s -> %s for expr %s", node, hash_node, sym_expr + ) + + # store node in hash cons, don't delete/replace + elif sym_expr not in expr_to_proxy and not isinstance( + sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom) + ): # don't hash cons primitives + expr_to_proxy[sym_expr] = fx.Proxy(node, tracer=tracer) # type: ignore[arg-type] + + # We add sym_constrain_range calls for symbols later in any case if they're size-like or range-constrained, + # so calls before that are redundant. + if node.target in ( + torch.ops.aten.sym_constrain_range.default, + torch.ops.aten.sym_constrain_range_for_size.default, + ): + gm.graph.erase_node(node) + + defs = [] + + # AOTAutograd will create new symbols as the unbacked_bindings keys, which PropagateSymInts will set as + # equivalent, but the refinement calls we perform in this pass may struggle with associating the two. + # More concretely, when re-exporting/tracing, constraining only the new symbol may not communicate enough + # information about the old symbol when we re-export, raising errors on data-dependent guards. + # Call resolve_unbacked_bindings() to get the original symbol if present, otherwise we take it as is. + if unbacked_bindings := resolve_unbacked_bindings( + shape_env, node.meta.get("unbacked_bindings") + ): + for s, keypath in unbacked_bindings.items(): + defs.append(s) + + # TODO: some CSE when generating these nodes can probably + # help reduce graph size and improve compile time + def go(node, keypath): + if keypath == (): + return node + if ( + len(keypath) >= 2 + and isinstance(keypath[0], CallMethodKey) + and isinstance(keypath[1], pytree.SequenceKey) + ): + if keypath[0].name == "size": + return go( + graph.call_function( + torch.ops.aten.sym_size.int, + (node, keypath[1].idx), + ), + keypath[2:], + ) + if keypath[0].name == "stride": + return go( + graph.call_function( + torch.ops.aten.sym_stride.int, + (node, keypath[1].idx), + ), + keypath[2:], + ) + return go( + graph.call_method( + keypath[0].name, (node, keypath[1].idx) + ), + keypath[2:], + ) + elif isinstance(keypath[0], CallMethodKey): + return go( + graph.call_method(keypath[0].name, (node,)), keypath[1:] + ) + elif isinstance(keypath[0], pytree.SequenceKey): + return go( + graph.call_function( + operator.getitem, (node, keypath[0].idx) + ), + keypath[1:], + ) + elif isinstance(keypath[0], ConvertIntKey): + return go( + graph.call_function( + cast_symbool_to_symint_guardless, (node,) + ), + keypath[1:], + ) + elif isinstance(keypath[0], DivideByKey): + # TODO: need to assert divisibility + return go( + graph.call_function( + operator.floordiv, (node, keypath[0].divisor) + ), + keypath[1:], + ) + elif isinstance(keypath[0], InnerTensorKey): + return go( + graph.call_function( + getattr, (node, keypath[0].inner_name) + ), + keypath[1:], + ) + else: + raise AssertionError(f"unrecognized keypath {keypath}") + + if s not in expr_to_proxy: + with _set_node_metadata_hook(gm, _node_metadata_hook): + expr_to_proxy[s] = fx.Proxy( + go(node, keypath), tracer=tracer + ) + log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) + + for i0 in defs: + ras = ras_by_symbol.pop(i0, []) + # Before we perform any asserts, first apply range + # refinement. This is important, because if we are going + # to retrace the graph (and we typically are if we send + # the graph to AOTAutograd), we need to make sure we apply + # range refinement (ala _check_is_size) first, BEFORE we + # run any of the asserts. Otherwise, we may decide to + # perform substitutions based on the asserts which we then + # can't back out, because value ranges can only be applied + # to asserts.) + # + # A perhaps better long term plan is to avoid this order + # dependence by making it possible to refine ranges on + # arbitrary expressions, not just symbols. But it is not + # so easy to make use of this information, see + # https://twitter.com/ezyang/status/1745801370299482492 + # We actually made an attempt at this in + # https://github.com/pytorch/pytorch/pull/119043 + # which didn't work. + # + # Another ideas for how to do this: + # - Have bound_sympy be the source of truth of the ranges of any expression + # - Cache intermediate results for every subexpression of bound_sympy + # - This cache should be possible to edit to refine ranges + # + # One issue with this proposal is that if + # we have a bound on 2x, we are not going to be able to + # apply it for 4x. Similarly, we may have bounds for an + # equivalent expression that we are not applying because + # it's not a perfect match (e.g. x < y vs y > x)". + # + # The first issue we already have it and it's impossible + # to solve in general, so any implementation on a best + # effort basis should do. + # + # The second issue is a preexisting one. It can be mitigated + # with a normalization algorithm. In general, it may also + # be on a best effort basis, but since our grammar is not + # terribly difficult, chances are we could even fully + # normalize SymPy expressions... who knows. + if i0 in constrained_unbacked_symbols: + continue # constrain symbol just once + + if i0 in shape_env.size_like: + if export: + graph.call_function( + torch.ops.aten.sym_constrain_range_for_size.default, + (expr_to_proxy[i0].node,), + ) + else: + graph.call_function( + torch._check_is_size, (expr_to_proxy[i0].node,) + ) + + vr = shape_env.var_to_range[i0] + if vr.is_int and vr.upper == sys.maxsize - 1: + # treat upper bound == sys.maxsize - 1 for int symbols as +oo + # to avoid redundant runtime assert + vr = ValueRanges(vr.lower, int_oo) + if not shape_env._default_unspecified_value_range().issubset(vr): + # The runtime range is constrained, so add a runtime + # assert and also explicitly refine the range + # (refinement should not be necessary once runtime + # asserts cause refinement, but that's NYI) + def convert(s): + if s in (int_oo, -int_oo): + return None + try: + return int(s) + except TypeError: + return None + + if ( + expr_to_proxy[i0].node.target + != cast_symbool_to_symint_guardless + ): + # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts + # raises AOTAutograd errors on cast_symbool_to_symint_guardless + + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + nn_module_stack=node.meta.get("nn_module_stack"), + ), + ): + if (min_val := convert(vr.lower)) is not None: + ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + ge, + f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'", + ), + ) + added_asserts.add(i0 >= min_val) + if (max_val := convert(vr.upper)) is not None: + le = _sympy_interp(expr_to_proxy, i0 <= max_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + le, + f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'", + ), + ) + added_asserts.add(i0 <= max_val) + + constrained_unbacked_symbols.add(i0) + add_runtime_asserts(ras) + + # delete unused reified symbols + for expr, proxy in expr_to_proxy.items(): + if ( + isinstance(expr, sympy.Symbol) + and proxy.node.op != "placeholder" # keep placeholders intact + and not proxy.node.users + ): + log.debug("deleting unused reified symbol for %s", expr) + gm.graph.erase_node(proxy.node) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/shape_prop.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/shape_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..05fb3b5dbaf606d19b433a4959b8cd680f863fc6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/shape_prop.py @@ -0,0 +1,230 @@ +# mypy: ignore-errors + +import traceback +from typing import Any, NamedTuple, Optional + +import torch +import torch.fx +from torch._dispatch.python import enable_python_dispatcher +from torch._guards import detect_fake_mode +from torch._prims_common import definitely_contiguous_for_memory_format +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx._compatibility import compatibility +from torch.fx.node import map_aggregate, Node + + +__all__ = ["TensorMetadata", "ShapeProp"] + + +@compatibility(is_backward_compatible=True) +class TensorMetadata(NamedTuple): + # TensorMetadata is a structure containing pertinent information + # about a tensor within a PyTorch program. + + # General Tensor metadata + shape: torch.Size + dtype: torch.dtype + requires_grad: bool + stride: tuple[int, ...] + memory_format: Optional[torch.memory_format] + + # Quantization metadata + is_quantized: bool + qparams: dict[str, Any] + + +# When include_contiguity is True, we will set contiguity when its always true for the tensor. +# Some tensors can represent both contiguous and non-contiguous tensors. e.g: (u0, u1) with (u2, u3). +# In such situation contiguity is not set. We could also make it a tri-state i.e: (definitely_contiguous, +# contiguous, and unknown). +def _extract_tensor_metadata( + result: torch.Tensor, include_contiguity=True +) -> TensorMetadata: + """ + Extract a TensorMetadata NamedTuple describing `result`. + """ + shape = result.shape + dtype = result.dtype + requires_grad = result.requires_grad + stride = result.stride() if not is_sparse_any(result) else () + + memory_format = None + + if include_contiguity and not is_sparse_any(result): + memory_formats = { + torch.contiguous_format, + torch.channels_last, + torch.channels_last_3d, + } + for query_format in memory_formats: + if definitely_contiguous_for_memory_format( + result, memory_format=query_format + ): + memory_format = query_format + break + + is_quantized = result.is_quantized + qparams: dict[str, Any] = {} + if is_quantized: + qscheme = result.qscheme() + qparams["qscheme"] = qscheme + if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + qparams["scale"] = result.q_scale() # type: ignore[assignment] + qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment] + elif qscheme in { + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + torch.per_channel_symmetric, + }: + # In this branch, scale and zero_point are expected to be tensors, + # we store the values as immutable_list in TensorMetadata for + # easier serialization downstream + qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment] + qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] + qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] + + return TensorMetadata( + shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams + ) + + +@compatibility(is_backward_compatible=True) +class ShapeProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node and + record the shape and type of the result + into the corresponding node. + + Example: + In this example, we record the shape + and data type of a module given + an example input ``torch.randn(50, D_in)``. + We print the name, shape and dtype of each node. + + class TwoLayerNet(torch.nn.Module): + def __init__(self, D_in, H, D_out): + super().__init__() + self.linear1 = torch.nn.Linear(D_in, H) + self.linear2 = torch.nn.Linear(H, D_out) + def forward(self, x): + h_relu = self.linear1(x).clamp(min=0) + y_pred = self.linear2(h_relu) + return y_pred + N, D_in, H, D_out = 64, 1000, 100, 10 + x = torch.randn(N, D_in) + y = torch.randn(N, D_out) + model = TwoLayerNet(D_in, H, D_out) + gm = torch.fx.symbolic_trace(model) + sample_input = torch.randn(50, D_in) + ShapeProp(gm).propagate(sample_input) + + for node in gm.graph.nodes: + print(node.name, node.meta['tensor_meta'].dtype, + node.meta['tensor_meta'].shape) + + The output of this code is: + + x torch.float32 torch.Size([50, 1000]) + linear1 torch.float32 torch.Size([50, 100]) + clamp_1 torch.float32 torch.Size([50, 100]) + linear2 torch.float32 torch.Size([50, 10]) + output torch.float32 torch.Size([50, 10]) + + Args: + module (GraphModule): The module to be executed + fake_mode (FakeTensorMode): A fake mode for copying the gm + + """ + + def __init__(self, gm, fake_mode=None): + super().__init__(gm) + if fake_mode is None: + fake_mode = detect_fake_mode() + if fake_mode is not None: + from torch._dynamo.utils import deepcopy_to_fake_tensor + + # Note: + # We need fake execution cause the inputs are fake, however, we cannot fakify the module + # - because we need to write to the tensor_meta of the real module. So we fakify to + # produce a result (L131 below), to extract tensor meta, and then keep going. + # + # If we were to fakify, we would write to the wrong node, and then downstream fusion + # would be missing the tensor_meta. + # + # See torch/_inductor/overrides.py for where this is called upstream of fusion. + self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode) + self.fake_mode = fake_mode + else: + self.fake_module = None + self.fake_mode = None + + self.real_module = self.module + + def run_node(self, n: Node) -> Any: + from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + rebind_unbacked, + ) + + try: + if self.fake_module is not None: + # Hacky swap. Alternatively, we could do this with overriding + # call_module and get_attr. + self.module = self.fake_module + try: + if self.fake_mode is not None: + with self.fake_mode, enable_python_dispatcher(): + result = super().run_node(n) + rebind_unbacked(self.fake_mode.shape_env, n, result) + else: + result = super().run_node(n) + finally: + self.module = self.real_module + except Exception as e: + traceback.print_exc() + raise RuntimeError( + f"ShapeProp error for: node={n.format_node()} with meta={n.meta}" + ) from e + + found_tensor = False + + def extract_tensor_meta(obj): + if isinstance(obj, torch.Tensor): + nonlocal found_tensor + found_tensor = True + return _extract_tensor_metadata(obj) + else: + return obj + + meta = map_aggregate(result, extract_tensor_meta) + if found_tensor: + n.meta["tensor_meta"] = meta + + if self.fake_mode: + if (shape_env := self.fake_mode.shape_env) and ( + symbol_to_path := compute_unbacked_bindings(shape_env, result) + ): + n.meta["unbacked_bindings"] = symbol_to_path + + n.meta["type"] = type(result) + return result + + def propagate(self, *args): + """ + Run `module` via interpretation and return the result and + record the shape and type of each node. + + Args: + *args (Tensor): the sample input. + + Returns: + Any: The value returned from executing the Module + """ + if self.fake_mode is not None: + fake_args = [ + self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in args + ] + else: + fake_args = args + return super().run(*fake_args) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/split_module.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/split_module.py new file mode 100644 index 0000000000000000000000000000000000000000..413584070d1335e3c456b73c70f29e4881e378b9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/split_module.py @@ -0,0 +1,639 @@ +# mypy: allow-untyped-defs +import inspect +import logging +from collections import OrderedDict +from typing import Any, Callable, Optional + +import torch +from torch.fx._compatibility import compatibility +from torch.fx._utils import lazy_format_graph_code +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node + + +__all__ = ["Partition", "split_module"] +log = _LOGGER = logging.getLogger(__name__) + + +@compatibility(is_backward_compatible=True) +class Partition: + def __init__(self, name: str): + self.name: str = name + self.submod_name = f"submod_{name}" + self.node_names: list[str] = [] + self.inputs: dict[str, None] = {} + self.outputs: dict[str, None] = {} + self.dependencies: dict[str, None] = {} + self.dependents: dict[str, None] = {} + self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + self.environment: dict[Node, Node] = {} + self.targets: dict[str, Any] = {} + + def __repr__(self) -> str: + return ( + f"name: {self.name},\n" + f" nodes: {self.node_names},\n" + f" inputs: {self.inputs},\n" + f" outputs: {self.outputs},\n" + f" partitions depended on: {self.dependencies},\n" + f" partition dependents: {self.dependents}" + ) + + +def _get_attr_from_qualname(mod: torch.nn.Module, qualname: str) -> Any: + attr_val = mod + for atom in qualname.split("."): # type: ignore[union-attr] + if not hasattr(attr_val, atom): + raise AttributeError(f"Node target {qualname} not found!") + attr_val = getattr(attr_val, atom) + return attr_val + + +# Creates subgraphs out of main graph +@compatibility(is_backward_compatible=True) +def split_module( + m: GraphModule, + root_m: torch.nn.Module, + split_callback: Callable[[Node], int], + qualname_map: Optional[dict[str, str]] = None, + keep_original_order: Optional[bool] = False, + keep_original_node_name: Optional[bool] = False, + keep_original_input_name: bool = True, +): + """ + Creates subgraphs out of main graph + + Args: + m (GraphModule): Graph module to split + root_m (torch.nn.Module): root nn module. Not currently used. Included + because the root nn module is usually transformed via + torch.fx._symbolic_trace.symbolic_trace (see example below) + split_callback (Callable[[Node], int]): Callable function + that maps a given Node instance to a numeric partition identifier. + split_module will use this function as the policy for which operations + appear in which partitions in the output Module. + qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a + mapping from new target names in the module after split to old target + names in the original module. + keep_original_order: Optional[bool]: keep the original order of the GraphModule + or use the Topological order of the new constructed GraphModule + keep_original_node_name: Optional[bool]: If the partitioned graphs should + have the same node names as the original graph. + keep_original_input_name: bool: If the partitioned graphs should + have the same input names as the original graph. + + Returns: + GraphModule: the module after split. + + Example: + + This is a sample setup: + + import torch + from torch.fx.symbolic_trace import symbolic_trace + from torch.fx.graph_module import GraphModule + from torch.fx.node import Node + from torch.fx.passes.split_module import split_module + + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x, y): + z = self.linear(x + self.param).clamp(min=0.0, max=1.0) + w = self.linear(y).clamp(min=0.0, max=1.0) + return z + w + + # symbolically trace model + my_module = MyModule() + my_module_traced = symbolic_trace(my_module) + + # random mod partitioning + partition_counter = 0 + NPARTITIONS = 3 + + def mod_partition(node: Node): + global partition_counter + partition = partition_counter % NPARTITIONS + partition_counter = (partition_counter + 1) % NPARTITIONS + return partition + + # split module in module with submodules + module_with_submodules = split_module( + my_module_traced, my_module, mod_partition + ) + + Output looks like this. Original graph is broken into partitions + + > print(module_with_submodules) + GraphModule( + (submod_0): GraphModule( + (linear): Linear(in_features=4, out_features=5, bias=True) + ) + (submod_1): GraphModule( + (linear): Linear(in_features=4, out_features=5, bias=True) + ) + (submod_2): GraphModule() + ) + + def forward(self, x, y): + param = self.param + submod_0 = self.submod_0(x, param, y); x = param = y = None + getitem = submod_0[0] + getitem_1 = submod_0[1]; submod_0 = None + submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None + getitem_2 = submod_1[0] + getitem_3 = submod_1[1]; submod_1 = None + submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None + return submod_2 + + Output of split module is the same as output of input traced module. + This is an example within a test setting: + + > orig_out = my_module_traced(x, y) + > submodules_out = module_with_submodules(x, y) + > self.assertEqual(orig_out, submodules_out) + True + """ + + log.debug( + "%s", + lazy_format_graph_code("pre split_module", m, colored=True), + ) + + def construct_graph( + node: Node, + base_mod_env: dict[str, Node], + base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule], + ): + if node.op == "placeholder": + default_value = ( + node.args[0] if len(node.args) > 0 else inspect.Signature.empty + ) + if keep_original_node_name: + args = ( + () if default_value is inspect.Signature.empty else (default_value,) + ) + base_mod_env[node.name] = base_mod_graph.create_node( + "placeholder", + node.name, + args=args, # type: ignore[arg-type] + type_expr=node.type, + ) + else: + base_mod_env[node.name] = base_mod_graph.placeholder( + node.target, # type: ignore[arg-type] + type_expr=node.type, + default_value=default_value, + ) + base_mod_env[node.name].meta = node.meta.copy() + elif node.op == "get_attr": + base_mod_env[node.name] = base_mod_graph.get_attr(node.target) # type: ignore[arg-type] + base_mod_env[node.name].meta = node.meta.copy() + assert isinstance(node.target, str) + attr_val = _get_attr_from_qualname(m, node.target) + base_mod_attrs[node.target] = attr_val # type: ignore[index] + return base_mod_env, base_mod_attrs + + import sympy + + partitions: dict[str, Partition] = {} + orig_nodes: dict[str, Node] = {} + symbol_to_node: dict[sympy.Symbol, Node] = {} + + def record_cross_partition_use(def_node: Node, use_node: Optional[Node]): + from torch.fx.experimental.symbolic_shapes import free_symbols + + defined = getattr(def_node, "_fx_partition", None) + used = getattr(use_node, "_fx_partition", None) + + log.debug( + "record_cross_partition_use %s (%s) %s (%s)", + def_node.name, + defined, + use_node.name if use_node is not None else "-", + used, + ) + + if defined != used: + if defined is not None: + def_partition = partitions[defined] + def_partition.outputs.setdefault(def_node.name) + if used is not None: + def_partition.dependents.setdefault(used) + + if used is not None: + use_partition = partitions[used] + use_partition.inputs.setdefault(def_node.name) + # We have made def_node an input to the use_partition. If + # this input has symbolic symbols in its size, those also must + # be made as inputs to the partition + if (def_val := def_node.meta.get("example_value")) is not None: + for s in sorted(free_symbols(def_val), key=str): + s_node = symbol_to_node[s] + use_partition.inputs.setdefault(s_node.name) + if symbol_to_node[s].op != "placeholder": + # If the node that defines the symbol is not a + # placeholder, we must make it an output of the + # partition. Note that this may be in a different + # partition than defined! Although, this doesn't + # really make a difference for correctness, since + # defined is guaranteed to have the symbol in + # scope and can return it; you just get less + # optimal codegen in this case. + s_defined = getattr(s_node, "_fx_partition", None) + if s_defined is not None: + s_def_partition = partitions[s_defined] + s_def_partition.outputs.setdefault(s_node.name) + s_def_partition.dependents.setdefault(used) + if defined is not None: + use_partition.dependencies.setdefault(defined) + + def instantiate_node_partition_mapping(node): + partition_name = str(split_callback(node)) + log.debug( + "instantiate_node_partition_mapping %s (%s)", node.name, partition_name + ) + + # add node to partitions + partition = partitions.get(partition_name) + if partition is None: + partitions[partition_name] = partition = Partition(partition_name) + + partition.node_names.append(node.name) + node._fx_partition = partition_name + + # Global State Nodes are nodes which by their global state effects, + # "taint" all downstream nodes while they are active. + GLOBAL_STATE_NODES = [ + torch.amp._enter_autocast, + torch.amp._exit_autocast, + torch._C._set_grad_enabled, + ] + + # For grad regions: + # ------------------------ + # 1. first region: we do nothing + # 2. subsequent regions: we insert the set_grad at the beginning + grad_regions: OrderedDict[Node, set[int]] = OrderedDict() + + # For autocast regions: + # ------------------------ + # 1. first region: we will only insert the _exit at the end + # 2. intermediate regions: we will insert both the + # _enter at the beginning and _exit at the end + # 3. last region: we will only insert _enter at the beginning + # We will do so in the order in which the autocasts were instantiated. + autocast_regions: OrderedDict[Node, set[int]] = OrderedDict() + autocast_exits: dict[Node, Optional[Node]] = {} + + active_grad = None + active_autocasts = set() + + for node in m.graph.nodes: + # This will prefer placeholder bindings, because those come first. + # This is a little dangerous though: it is possible that an unbacked + # symbol is used without any binding site for it, in which case we + # will get a KeyError not able to find it. I'd like to fix this by + # having passes.runtime_assert establish some invariants that I can + # rely on later, but this needs some extra work. Quick fix first. + # See https://github.com/pytorch/pytorch/issues/130534 + if ( + (val := node.meta.get("example_value")) is not None + and isinstance(val, (torch.SymInt, torch.SymFloat)) + and isinstance(s0 := val.node.expr, sympy.Symbol) + and s0 not in symbol_to_node + ): + symbol_to_node[val.node.expr] = node + + if node.op in ["placeholder", "get_attr", "output"]: + continue + + instantiate_node_partition_mapping(node) + + if node.op == "call_function" and node.target in GLOBAL_STATE_NODES: + if node.target == torch._C._set_grad_enabled: + assert len(node.args) == 1 + assert isinstance(node.args[0], bool) + active_grad = node + grad_regions[active_grad] = set({split_callback(node)}) + elif node.target == torch.amp._enter_autocast: + # Should all be python constants + assert all(not isinstance(arg, Node) for arg in node.args) + active_autocasts.add(node) + autocast_regions[node] = set({split_callback(node)}) + autocast_exits[node] = None + elif node.target == torch.amp._exit_autocast: + assert len(node.args) == 1 + autocast_regions[node.args[0]].add(split_callback(node)) + active_autocasts.remove(node.args[0]) + autocast_exits[node.args[0]] = node + + if active_grad is not None: + grad_regions[active_grad].add(split_callback(node)) + + for a in active_autocasts: + autocast_regions[a].add(split_callback(node)) + + assert all(v is not None for v in autocast_exits.values()), "autocast must exit" + + autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()} + grad_regions = {k: sorted(v) for k, v in grad_regions.items()} + + if _LOGGER.isEnabledFor(logging.DEBUG): + _LOGGER.debug("autocast_regions: %s", autocast_regions) + _LOGGER.debug("grad_regions: %s", grad_regions) + + assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions) + + # split nodes into partitions + highest_partition = -1 + for node in m.graph.nodes: + orig_nodes[node.name] = node + + # TODO currently placeholders/parameters aren't put into random partitions, + # rather they're added to the graphs where they are used down below + if node.op in ["placeholder", "get_attr"]: + continue + if node.op == "output": + torch.fx.graph.map_arg( + node.args[0], lambda n: record_cross_partition_use(n, None) + ) + continue + + if assert_monotonically_increasing: + pid = split_callback(node) + assert highest_partition <= pid, ( + "autocast or set_grad_enabled require monotonically increasing partitions:" + f"highest: {highest_partition}, this node's: {pid}" + ) + highest_partition = pid + + # do not capture cross-partition dependencies for global state nodes as they will be + # self-contained - their setup and unwind will be isolated to each partition submodule. + if node.target not in GLOBAL_STATE_NODES: + torch.fx.graph.map_arg( + node.args, lambda def_node: record_cross_partition_use(def_node, node) + ) + torch.fx.graph.map_arg( + node.kwargs, lambda def_node: record_cross_partition_use(def_node, node) + ) # noqa: B950 + + original_partition_order = list(partitions.keys()) + # find partitions with no dependencies + root_partitions: list[str] = [] + for partition_name, partition in partitions.items(): + if not len(partition.dependencies): + root_partitions.append(partition_name) + + # check partitions for circular dependencies and create topological partition ordering + sorted_partitions: list[str] = [] + while root_partitions: + root_partition = root_partitions.pop() + sorted_partitions.append(root_partition) + for dependent in partitions[root_partition].dependents: + partitions[dependent].dependencies.pop(root_partition) + if not partitions[dependent].dependencies: + root_partitions.append(dependent) + if len(sorted_partitions) != len(partitions): + raise RuntimeError("cycle exists between partitions!") + + # Enter prelude + for regions_mapping in [autocast_regions, grad_regions]: + for node, regions in regions_mapping.items(): + assert len(regions) > 0 + partitions[str(regions[0])].environment[node] = node + for r in regions[1:]: + partition = partitions[str(r)] + new_node = partition.graph.create_node( + op=node.op, + target=node.target, + args=tuple(arg for arg in node.args), + kwargs={}, + type_expr=node.type, + ) + new_node.meta = ( + node.meta.copy() + ) # is it really a good idea to copy this? + partition.environment[node] = new_node + + # add placeholders to partition inputs + for partition_name in sorted_partitions: + partition = partitions[partition_name] + new_inputs: dict[str, None] = {} + + counter = 0 + + for inp in partition.inputs: + orig_node = orig_nodes[inp] + # We don't pass in get_attr nodes as inputs to the partition, but + # instead set them as targets and use getattr within the module + + def add_placeholder(): + if keep_original_input_name: + name = inp + else: + nonlocal counter + name = f"arg_{counter}" + counter += 1 + placeholder = partition.graph.placeholder( + name, + type_expr=orig_nodes[inp].type, + ) + new_inputs[inp] = None + return placeholder + + if orig_node.op == "get_attr": + assert isinstance(orig_node.target, str) + + orig_attr = _get_attr_from_qualname(m, orig_node.target) + if isinstance(orig_attr, torch.nn.Module): + placeholder = partition.graph.get_attr(orig_node.target) + partition.targets[orig_node.target] = orig_attr + else: + placeholder = add_placeholder() + else: + placeholder = add_placeholder() + placeholder.meta = orig_nodes[inp].meta.copy() + partition.environment[orig_nodes[inp]] = placeholder + partition.inputs = new_inputs + + # Transform nodes and collect targets for partition's submodule + for node in m.graph.nodes: + if hasattr(node, "_fx_partition"): + partition = partitions[node._fx_partition] + + # swap out old graph nodes in kw/args with references to new nodes in this submodule + environment = partition.environment + gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) + gathered_kwargs = torch.fx.graph.map_arg( + node.kwargs, lambda n: environment[n] + ) + + if node.op not in ["call_module", "get_attr"]: + target = node.target + else: + target_attr = _get_attr_from_qualname(m, node.target) + target = node.target.replace(".", "_") + partition.targets[target] = target_attr + # Fill in the passed-in mapping from new qualname to old qualname + if qualname_map is not None: + # When creating the split module later, the submodules will have + # path prefix matching the corresponding partition's submod_name + qualname = f"{partition.submod_name}.{target}" + qualname_map[qualname] = node.target + + assert isinstance(gathered_args, tuple) + assert isinstance(gathered_kwargs, dict) + name = node.name if keep_original_node_name else None + new_node = partition.graph.create_node( + op=node.op, + target=target, + args=gathered_args, + kwargs=gathered_kwargs, + type_expr=node.type, + name=name, + ) + new_node.meta = node.meta.copy() + partition.environment[node] = new_node + + # Exit epilogue + for regions_mapping in [autocast_regions]: + for node in reversed(regions_mapping): + regions = regions_mapping[node] + assert len(regions) > 0 + for r in regions[:-1]: + partition = partitions[str(r)] + exit_node = autocast_exits[node] + assert exit_node is not None, "Missing exit node" + new_node = partition.graph.create_node( + op=exit_node.op, + target=exit_node.target, + args=(partition.environment[node],), + kwargs={}, + type_expr=exit_node.type, + ) + new_node.meta = ( + exit_node.meta.copy() + ) # is it really a good idea to copy this? + + # original module environment dict mapping node names to nodes + orig_mod_env: dict[str, Node] = {} + # Set up values to construct base module + base_mod_env: dict[str, Node] = {} + base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule] = {} + if not keep_original_order: + for node in m.graph.nodes: + base_mod_env, base_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + + else: + # Go through the graph to construct the mapping dict + for node in m.graph.nodes: + orig_mod_env[node.name] = node + + # Do some things iterating over the partitions in topological order again: + # 1) Finish off submodule Graphs by setting corresponding outputs + # 2) Construct GraphModules for each submodule + # 3) Construct the base graph by emitting calls to those submodules in + # topological order or original order specified by keep_original_order + + construct_order_partitions = ( + sorted_partitions if not keep_original_order else original_partition_order + ) + + already_constructed_attr_nodes = set() + + # We actually need to insert the placeholder nodes in the original order + # otherwise graph signature will be wrong. + original_order = [node for node in m.graph.nodes if node.op == "placeholder"] + + for partition_name in construct_order_partitions: + partition = partitions[partition_name] + + # Set correct output values + output_vals = tuple( + partition.environment[orig_nodes[name]] for name in partition.outputs + ) + + # skip output node generation if there are no output values + num_output_vals = len(output_vals) + if num_output_vals == 1: + partition.graph.output(output_vals[0]) + elif num_output_vals > 1: + partition.graph.output(output_vals) + else: + # Invariant - Graph should always have an output node. + partition.graph.output(()) + + if keep_original_order: + # first get the attr nodes required by this partition + orig_mod_attr_nodes: list[Node] = [ + orig_mod_env[key] + for key in partition.inputs + if key not in original_order + ] + + for node in original_order: + if node in already_constructed_attr_nodes: + continue # already added this attr to the base graph + base_mod_env, _based_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + already_constructed_attr_nodes.add(node) + + # Construct GraphModule for this partition + for node in orig_mod_attr_nodes: # type: ignore[attr-defined] + if node in already_constructed_attr_nodes: + continue + base_mod_env, base_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + already_constructed_attr_nodes.add(node) + + base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule( + partition.targets, partition.graph + ) # noqa: B950 + + # Emit call in base graph to this submodule + output_val = base_mod_graph.call_module( + partition.submod_name, + tuple(base_mod_env[name] for name in partition.inputs), + ) + + num_outputs = len(partition.outputs) + if num_outputs > 1: + # Unpack multiple return values from submodule + output_val_proxy = torch.fx.proxy.Proxy(output_val) + for i, output_name in enumerate(partition.outputs): + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + elif num_outputs == 1: + base_mod_env[next(iter(partition.outputs))] = output_val + + # When keep_original_order=True and if the graph doesn't have any + # `call_function` node then `base_mod_graph`, `base_mod_env` and `base_mod_attrs` + # are never populated. + # For this case, we call `construct_graph` here which takes care of updating them. + if keep_original_order and not base_mod_env: + for node in m.graph.nodes: + base_mod_env, base_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + + # Add output node to `base_mod_graph` (i.e. the split graph) which will be returned. + for node in m.graph.nodes: + if node.op == "output": + base_mod_graph.output( + torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]) + ) # noqa: B950 + + ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) + log.debug( + "%s", + lazy_format_graph_code("post split_module", ret, colored=True), + ) + return ret diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/split_utils.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/split_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..079b1b4364bd8c974a5bb8bc46955dc2e84f0c6b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/split_utils.py @@ -0,0 +1,307 @@ +# mypy: allow-untyped-defs +import copy +from dataclasses import dataclass, field +from typing import Optional, Union + +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.graph import map_arg +from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module + +from .tools_common import NodeList + + +__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"] + + +@compatibility(is_backward_compatible=False) +def getattr_recursive(obj, name): + for layer in name.split("."): + if hasattr(obj, layer): + obj = getattr(obj, layer) + else: + return None + return obj + + +@compatibility(is_backward_compatible=False) +def setattr_recursive(obj, attr, value): + if "." not in attr: + setattr(obj, attr, value) + else: + layer = attr.split(".") + setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value) + + +@compatibility(is_backward_compatible=False) +@dataclass +class Component: + """ + A component serves as a container for a subgraph we want to create afterwards. + """ + + graph: torch.fx.Graph + order: int + name: str + + # Stores the placeholder nodes in `graph`. + input_placeholders: list = field(default_factory=list) + + # Store the nodes in original graph that are placeholder in `graph`. + orig_inputs: list = field(default_factory=list) + + # Store the nodes in original graph that are outputs in `graph`. + orig_outputs: list = field(default_factory=list) + + # Mapping from get_attr node in original graph to get_attr node in `graph`. + getattr_maps: dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict) + constructor_args: list[str] = field(default_factory=list) + gm: Optional[torch.fx.GraphModule] = None + + +@compatibility(is_backward_compatible=False) +def split_by_tags( + gm: torch.fx.GraphModule, + tags: list[str], + return_fqn_mapping: bool = False, + return_tuple: bool = False, + GraphModuleCls: type[torch.fx.GraphModule] = torch.fx.GraphModule, +) -> Union[torch.fx.GraphModule, tuple[torch.fx.GraphModule, dict[str, str]]]: + """ + Splits a GraphModule using tags on its graph nodes. We honor the order of + tags. For example, we have tags = ["a", "b", "c"], the function will create + the initial submodules in the order of "a", "b", "c". + + To set a tag: + gm.graph.nodes[idx].tag = "mytag" + + This will result in all nodes with the same tag being extracted and placed in their + own submodule. For placeholder, output and get_attr node, the tag is ignored. placeholder + and output nodes are created when needed while get_attr nodes get copied to submodules + where they are used. + + Given the following module def: + + class SimpleModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = torch.nn.Linear(...) + self.linear2 = torch.nn.Linear(...) + self.linear3 = torch.nn.Linear(...) + + def forward(self, in1, in2): + r1 = self.linear1(in1) + r2 = self.linear2(in2) + r3 = torch.cat([r1, r2]) + return self.linear3(r3) + + Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split: + + ro: + def forward(self, in1): + self = self.root + linear1 = self.linear1(in1) + return linear1 + + main: + def forward(self, in2, linear1): + self = self.root + linear2 = self.linear2(in2) + cat_1 = torch.cat([linear1, linear2]) + linear3 = self.linear3(cat_1) + return linear3 + + main: + def forward(self, in1, in2): + self = self.root + ro_0 = self.ro_0(in1) + main_1 = self.main_1(in2, ro_0) + return main_1 + + Returns: + split_gm: torch fx graph after split + orig_to_split_fqn_mapping: a map between the original fqn and the fqn + after split for call_module and get_attr. + """ + + def flatten(x: torch.fx.node.Argument) -> NodeList: + """ + Stores nodes in x to a list and returns the list. + """ + r: NodeList = [] + map_arg(x, r.append) + return r + + # Mapping from node in original module to node in created submodule. + node_remapping: dict[torch.fx.Node, torch.fx.Node] = {} + + # Mapping from node in original module or created submodules to + # corresponding component. + node_to_component: dict[torch.fx.Node, Component] = {} + + # Mapping from tag to the corresponding component. + tag_to_component: dict[str, Component] = {} + + # Stores all components. + all_components: list[Component] = [] + + # Stores nodes that will be used in main graph. + used_in_main: dict[torch.fx.Node, None] = {} + + # Main graph after split. + main_g = torch.fx.Graph() + + # Mapping from node in original module to node in main graph after split. + main_remapping: dict[torch.fx.Node, torch.fx.Node] = {} + + # Output node of original module. + output_node: Optional[torch.fx.Node] = None + + # Create a component for each tag, we don't expect to create other components afterwards. + for tag in tags: + comp = Component(torch.fx.Graph(), len(all_components), f"{tag}") + all_components.append(comp) + tag_to_component[tag] = comp + + # Traverse the nodes in original graph and take care of them. + for node in gm.graph.nodes: + if node.op == "output": + if output_node is not None: + raise RuntimeError("Multiple output nodes in graph!") + output_node = node + continue + + # Placeholders in the original graph get copied to main graph. + if node.op == "placeholder": + main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type) + main_remapping[node].meta = copy.copy(node.meta) + continue + + # Get_attr nodes are ignored because we are not tagging them. + # Instead, we copy them directly to the submodules use them afterwards. + if node.op == "get_attr": + continue + + # Now we process callable nodes which are nodes with op of call_module, + # call_function or call_method. Every callable nodes should be tagged. + assert hasattr(node, "tag"), f"Node does not have tag: {node.format_node()}" + + upstream_components = [ + node_to_component[x] + for x in flatten(node.args) + flatten(node.kwargs) + if x.op not in {"placeholder", "get_attr"} + ] + + comp = tag_to_component[node.tag] + node_to_component[node] = comp + + # Max order of upperstream components. + mx = max((c.order for c in upstream_components), default=0) + + # Expect the component for `node` has higher order then its upstream components. + assert comp.order >= mx, ( + f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}" + ) + + # Map a input of `node` to nodes in the component's graph. + def remap_func(x): + # If input is a get_attr node, copy it to current component's graph. + # Returns the get_attr node in current component's graph. + if x.op == "get_attr": + if x not in comp.getattr_maps: + comp.getattr_maps[x] = comp.graph.get_attr( + x.target, type_expr=x.type + ) + comp.getattr_maps[x].meta = copy.copy(x.meta) + return comp.getattr_maps[x] + + # If input is not a placeholder, it should have been put into a component + # already. If it's the current component then we return the corresponding + # node in the component. + if x.op != "placeholder" and node_to_component[x] == comp: + return node_remapping[x] + + # If input is a placeholder or it's in other components, we want to make it + # as a placeholder in current component's graph. + if x not in comp.orig_inputs: + comp.orig_inputs.append(x) + placeholder = comp.graph.placeholder(x.name, type_expr=x.type) + placeholder.meta = copy.copy(x.meta) + comp.input_placeholders.append(placeholder) + used_in_main[x] = None + + return comp.input_placeholders[comp.orig_inputs.index(x)] + + n = comp.graph.node_copy(node, remap_func) + n.tag = node.tag # type: ignore[attr-defined] + node_remapping[node] = n + node_to_component[n] = comp + + if output_node is None: + raise RuntimeError("Graph had no output node!") + + for x in flatten(output_node.args[0]): + if x.op == "get_attr": + # We don't need components mapping for nodes of type "get_attr" + # that are consumed by the output. Only need to make sure we create + # corresponding counterparts in the resulting graph. + main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type) + else: + # All component results consumed by the output node should be + # marked as "used in main". + used_in_main[x] = None + + # If a node is used in main graph then we mark it as an output in the component + # it belongs to. + for n in used_in_main: + if n.op != "placeholder": + node_to_component[n].orig_outputs.append(n) + + # Now we create a graphmodule for each component. + orig_to_split_fqn_mapping: dict[str, str] = {} + for comp in all_components: + outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs)) + + if return_tuple: + comp.graph.output(outs) + else: + # Take care of the args of FX output node. If there's a single + # output then the output node args is like (output_single), else + # if there're multiple outputs then the output node args is like + # ((output_0, output_1, ...)). + comp.graph.output(outs[0] if len(outs) == 1 else outs) + + comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module( + gm, subgraph=comp.graph, comp_name=comp.name + ) + orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping) + + # Create a call_module node in main graph. + main_node = main_g.call_module( + comp.name, + args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)), + kwargs=None, + ) + + if len(outs) == 1 and not return_tuple: + main_remapping[comp.orig_outputs[0]] = main_node + else: + for i, o in enumerate(comp.orig_outputs): + # Use Proxy to record getitem access. + main_remapping[o] = torch.fx.Proxy(main_node)[i].node # type: ignore[index] + + main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__)) + main_root = HolderModule({comp.name: comp.gm for comp in all_components}) + main_g._codegen = gm.graph._codegen + + # If the output nodes consumes get_attr directly in the original graph, + # then we need to make sure get_attr is copied to the new graph. + for x in flatten(output_node.args[0]): + if x.op == "get_attr": + setattr(main_root, x.name, getattr_recursive(gm, x.target)) # type: ignore[arg-type] + + result_gm = GraphModuleCls(main_root, main_g) + if return_fqn_mapping: + return result_gm, orig_to_split_fqn_mapping + + return result_gm diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/splitter_base.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/splitter_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2b1d9ee616b2a9a93554f727fdafb0f7e104fb54 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/splitter_base.py @@ -0,0 +1,925 @@ +# mypy: allow-untyped-defs +import argparse +import copy +import logging +from collections import defaultdict +from collections.abc import Iterable, Sequence +from dataclasses import dataclass +from typing import Any, NamedTuple, Optional + +import torch +from torch.fx._compatibility import compatibility +from torch.fx.node import map_arg +from torch.fx.passes.graph_manipulation import get_size_of_node + +from .graph_drawer import FxGraphDrawer +from .operator_support import get_node_target, OperatorSupportBase +from .shape_prop import ShapeProp +from .split_utils import split_by_tags +from .tools_common import ( + CALLABLE_NODE_OPS, + FxNetAccFusionsFinder, + is_node_output_tensor, + NodeList, + NodeSet, + Tensors, +) + + +__all__ = [ + "FxNetAccNodesFinder", + "FxNetSplitterInternalError", + "Subgraph", + "SplitResult", + "generate_inputs_for_submodules", +] +_LOGGER = logging.getLogger(__name__) + +DEFAULT_MIN_ACC_MODULE_SIZE = 1 +DEFAULT_SKIP_FUSION = False +DEFAULT_ALLOW_NON_TENSOR = False + + +class _SplitterSettingBase: + def __init__( + self, + min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE, + skip_fusion=DEFAULT_SKIP_FUSION, + allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR, + max_acc_splits: int = -1, + ): + parser = argparse.ArgumentParser() + parser.add_argument( + "--min-acc-module-size", + "--min_acc_module_size", + required=False, + type=int, + help="Minimum size limit of an accelerator subgraph.", + ) + parser.add_argument( + "--max-acc-splits", + "--max_acc_splits", + required=False, + type=int, + help="Enforce a maximum number of split subgraphs.", + ) + parser.add_argument( + "--skip-fusion", + "--skip_fusion", + default=False, + action="store_true", + help="If true then no fusion groups. Fusion group is used to " + "enforce no non-tensor data flow between submodules. If we don't " + "have this constrain, setting this to false is recommended as it " + "can reduce overhead.", + ) + parser.add_argument( + "--allow-non-tensor", + "--allow_non_tensor", + default=False, + action="store_true", + help="For some backends non-tensor data flow between cpu and them " + "are not allowed. Therefore, if a node supported by accelerator but " + "it has non-tensor inputs or outputs to a cpu node we would want to " + "consider it as a cpu node during splitting. However, for some backends " + "we might not care about non-tensor data flow and we can set this option " + "to true to disable the functionality that prevent non-tensor data flow.", + ) + args, _unknown = parser.parse_known_args() + + self.min_acc_module_size: int = ( + args.min_acc_module_size + if args.min_acc_module_size + else min_acc_module_size + ) + self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion + self.allow_non_tensor: bool = ( + args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor + ) + self.max_acc_splits: int = max_acc_splits + + +@compatibility(is_backward_compatible=False) +class FxNetAccNodesFinder: + """ + Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor + input/output to cpu nodes to prevent non-tensor data flow between backends and cpu. + + I.e. if we have a chain: + + ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1 + + where every ACC node produces non-tensor output, then they all should be treated as CPU nodes. + + This behavior can be turned off by passing allow_non_tensor=True. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + operator_support: OperatorSupportBase, + allow_non_tensor: bool, + ): + self.module = module + self.operator_support = operator_support + self.allow_non_tensor = allow_non_tensor + self.acc_nodes: NodeSet = set() + + def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList): + """ + Transitively excludes nodes from ACC supported set. + For every node in the worklist: + - removes its downstream ACC nodes from ACC supported set, + - if any downstream ACC node produces non-tensor output, + then it gets added into the worklist. + """ + while cpu_worklist: + node = cpu_worklist.pop(0) + + for user in node.users: + if user in self.acc_nodes: + self.acc_nodes.remove(user) + if not is_node_output_tensor(user): + cpu_worklist.append(user) + + def reduce_acc_nodes_non_tensor_input(self): + """ + Excludes nodes from ACC supported set that have direct + upstream CPU nodes that produce non-tensor outputs. + """ + non_tensor_cpu_nodes: NodeList = [] + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + if node in self.acc_nodes: + continue + if is_node_output_tensor(node): + continue + non_tensor_cpu_nodes.append(node) + + self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes) + + def reduce_acc_nodes_non_tensor_output(self): + """ + Excludes nodes from ACC supported set that produce non-tensor + outputs and have downstream CPU nodes. + """ + while True: + new_cpu_nodes: NodeList = [] + + for acc_node in self.acc_nodes: + if is_node_output_tensor(acc_node): + continue + for user in acc_node.users: + if user not in self.acc_nodes: + new_cpu_nodes.append(acc_node) + break + + if not new_cpu_nodes: + break + + for new_cpu_node in new_cpu_nodes: + self.acc_nodes.remove(new_cpu_node) + + self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes) + + def __call__(self) -> NodeSet: + submodules = dict(self.module.named_modules()) + self.acc_nodes = { + n + for n in self.module.graph.nodes + if n.op in CALLABLE_NODE_OPS + and self.operator_support.is_node_supported(submodules, n) + } + + if not self.allow_non_tensor: + self.reduce_acc_nodes_non_tensor_input() + self.reduce_acc_nodes_non_tensor_output() + + return self.acc_nodes + + +@compatibility(is_backward_compatible=False) +class FxNetSplitterInternalError(Exception): + pass + + +@compatibility(is_backward_compatible=False) +@dataclass +class Subgraph: + is_acc: bool + nodes: NodeList + device_ordinal: Optional[int] = None + + +@compatibility(is_backward_compatible=False) +class SplitResult(NamedTuple): + """ + Stores the results of the splitter. + + Attributes: + split_module: root module after splitting. + submodule_inputs: a dict that maps submodule name to its inputs. + non_acc_submodule_prefix: the prefix for non acc submodules. For + acc submodule the prefix is alwasy "_run_on_acc_". + """ + + split_module: torch.fx.GraphModule + submodule_inputs: dict[str, Any] + non_acc_submodule_prefix: str + + +@compatibility(is_backward_compatible=False) +def generate_inputs_for_submodules( + model: torch.nn.Module, + inputs: Sequence[Any], + target_submodules: Iterable[str], + deepcopy: bool = False, +) -> dict[str, Any]: + """ + Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this + function doesn't work. + + Args: + model: root model. + inputs: inputs to the root model. + target_submodules: submodules that we want to generate inputs for. + + Returns: + A dict that maps from submodule name to its inputs. + """ + + handles = [] + results = {} + submodule_to_names = {mod: name for name, mod in model.named_modules()} + + def pre_forward(module, module_inputs): + results[submodule_to_names[module]] = ( + copy.deepcopy(module_inputs) if deepcopy else module_inputs + ) + + for name, mod in model.named_modules(): + if name in target_submodules: + handles.append(mod.register_forward_pre_hook(pre_forward)) + + def clean_up_handles(): + for h in handles: + h.remove() + + try: + with torch.no_grad(): + model(*inputs) + except Exception as e: + clean_up_handles() + raise e + + clean_up_handles() + return results + + +class _SplitterBase: + """ + Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator. + Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible. + Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator. + + Given the following graph: + ==> b ==> + // \\ + a d + \\ // + ==> c ==> + + class SimpleModule(torch.nn.Module): + def forward(self, a): + b = torch.sin(a) + c = torch.cos(a) + d = b + c + return d + + and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator, + we will get the following split result: + + main: + def forward(self, a): + run_on_acc_0_0 = self._run_on_acc_0_0(a) + getitem = run_on_acc_0_0[0] + getitem_1 = run_on_acc_0_0[1] + run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1) + return run_on_cpu_1_1 + + _run_on_acc_0_0: + def forward(self, a): + sin_1 = torch.sin(a) + cos_1 = torch.cos(a) + return (sin_1, cos_1) + + _run_on_cpu_1_1: + def forward(self, sin_1, cos_1): + add_1 = sin_1 + cos_1 + return add_1 + """ + + # PCIe bandwidth for the backend, default to 100 GB/s + PCIe_BW = 100 * 2**30 + + def __init__( + self, + module: torch.fx.GraphModule, + sample_input: Sequence[Any], + operator_support: OperatorSupportBase, + settings: _SplitterSettingBase, + non_acc_submodule_name: str = "_run_on_cpu_", + return_tuple: bool = False, + nodes_finder: Optional[FxNetAccNodesFinder] = None, + ): + """ + Preprocesses graph before splitting: + - finds nodes supported by ACC, + - finds fusion groups for ACC nodes having non-tensor IO, + - builds a graph of direct dependencies, + - builds a map of fused nodes to their fusions. + As a result we get self.acc_nodes, self.deps and self.fusions. + """ + assert isinstance(module, torch.fx.GraphModule) + + self.module = module + ShapeProp(self.module).propagate(*sample_input) + + self.settings = settings + self.operator_support = operator_support + self.sample_input = sample_input + if nodes_finder is None: + nodes_finder = FxNetAccNodesFinder( + self.module, self.operator_support, self.settings.allow_non_tensor + ) + self.acc_nodes = nodes_finder() + + if self.settings.skip_fusion: + self.fusions = {} + else: + self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)() + + # Modify deps to add more deps for fused nodes + self.deps = self.find_deps() + self.update_deps_for_fusions() + + self.non_acc_submodule_name = non_acc_submodule_name + self._node_submodule_map: dict[str, str] = {} + self._return_tuple = return_tuple + + self.tags: list[str] = [] + + # =============================================================== + # Helpers for ctor and initial state + # =============================================================== + + def get_node_submodule_map(self) -> dict[str, str]: + """Returns a map from node name to submodule name, e.g. + node: main_module_impl_impl_over_arch_unary_multiple_embedding + _pooling_embedding_pooling_sparse_entity_equivalence_key + _proxy_embedding_bag + maps to submodule name of: _run_on_acc_1 + """ + return self._node_submodule_map + + def find_deps(self) -> dict[torch.fx.Node, NodeSet]: + """ + Builds a graph of node dependencies. Leaf nodes don't have any + dependencies and the "output" node doesn't have nodes depending on it. + + Resulting graph has only direct dependencies, i.e. there are no + transitive dependencies. + """ + deps: dict[torch.fx.Node, NodeSet] = defaultdict(set) + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + for user in node.users: + if user.op != "output": + deps[user].add(node) + return deps + + def update_deps_for_fusions(self): + """ + Updates graph of dependencies so that: + - nodes from the same fusion depend on the same set of outer nodes, + - outer nodes depending on a fusion depend on all nodes in that fusion. + """ + for node in self.fusions: + fusion = self.fusions[node] + for fused_neighbor in fusion: + self.deps[node].update(self.deps[fused_neighbor] - fusion) + + for user in fused_neighbor.users: + if user not in fusion: + self.deps[user].add(node) + + # =============================================================== + # Helpers for preview + # =============================================================== + + def _lower_model_to_backend( + self, mod: torch.fx.GraphModule, inputs: Tensors + ) -> torch.nn.Module: + """ + Lower the model to a backend. + """ + + return mod + + def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str: + """ + When an error occurs during lowering or running the lowered mod, we use this + function to find culprits in the `mod` that causes the error. + """ + + return "Unable to find a culprit because _find_culprit() function is not implemented." + + def _draw_graph_based_on_node_support( + self, mod: torch.fx.GraphModule, supported_nodes: NodeList + ): + color_map = { + "default": "AliceBlue", + "supported": "chartreuse1", + "unsupported": "crimson", + } + + class CustomDrawer(FxGraphDrawer): + def _get_node_style(self, node): + template = super()._get_node_style(node) + if node in supported_nodes: + template["fillcolor"] = color_map["supported"] + elif node.op in CALLABLE_NODE_OPS: + template["fillcolor"] = color_map["unsupported"] + else: + template["fillcolor"] = color_map["default"] + + return template + + drawer = CustomDrawer(mod, "node_support", ignore_getattr=True) + dot_graph = drawer.get_main_dot_graph() + # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`. + dot_graph.write_raw("node_support.dot") # type: ignore[attr-defined] + + def node_support_preview(self, dump_graph: bool = False): + submodules = dict(self.module.named_modules()) + + supported_nodes: NodeList = [] + supported_node_types = defaultdict(set) + unsupported_node_types = defaultdict(set) + + def get_dtype(arg): + tensor_meta = arg.meta.get("tensor_meta") + return getattr(tensor_meta, "dtype", None) + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + target = get_node_target(submodules, node) + + # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None. + arg_dtypes = [ + get_dtype(arg) if isinstance(arg, torch.fx.Node) else None + for arg in node.args + ] + + # Find last non-None element. If all elements are None, return max_len. + last_index = len(arg_dtypes) - next( + ( + i + for i, dtype in enumerate(reversed(arg_dtypes)) + if dtype is not None + ), + len(arg_dtypes), + ) + + # Strip None elements at the end. + arg_dtypes_tuple = tuple(arg_dtypes[:last_index]) + kwarg_dtypes_tuple = tuple( + (k, get_dtype(arg)) + for k, arg in node.kwargs.items() + if isinstance(arg, torch.fx.Node) + ) + + if self.operator_support.is_node_supported(submodules, node): + supported_nodes.append(node) + supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) + else: + unsupported_node_types[target].add( + (arg_dtypes_tuple, kwarg_dtypes_tuple) + ) + + if dump_graph: + self._draw_graph_based_on_node_support(self.module, supported_nodes) + + reports = "\nSupported node types in the model:\n" + for t, dtypes in supported_node_types.items(): + for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: + reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" + + reports += "\nUnsupported node types in the model:\n" + for t, dtypes in unsupported_node_types.items(): + for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: + reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" + + print(reports) + + # Return reports for testing purpose + return reports + + def split_preview(self, dump_graph: bool = False): + reports = "" + subgraphs = self.put_nodes_into_subgraphs() + acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) + cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num + reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" + reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" + + subgraphs = self.remove_small_acc_subgraphs(subgraphs) + acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) + cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num + reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" + reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" + + for i, subgraph in enumerate(subgraphs): + reports += ( + f"_run_on_acc_{i}: " + if subgraph.is_acc + else f"{self.non_acc_submodule_name}{i}: " + ) + reports += f"{len(subgraph.nodes)} node(s)\n" + + self.tag(subgraphs) + split_mod = self.split(remove_tag=True) + split_mod.eval() + + if dump_graph: + drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True) + dot_graphs = drawer.get_all_dot_graphs() + for name, dot_graph in dot_graphs.items(): + # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`. + dot_graph.write_raw(f"{name}.dot") # type: ignore[attr-defined] + + max_qps: float = self.PCIe_BW + bottleneck_module = "" + + for node in split_mod.graph.nodes: + if node.op == "call_module" and "acc" in node.target: + reports += f"\nProcessing acc submodule {node.target}\n" + + submod = getattr(split_mod, node.target) + + def get_submod_inputs(main_mod, submod, example_inputs): + sub_inputs = None + + def get_inputs(self, inputs): + nonlocal sub_inputs + sub_inputs = inputs + + handle = submod.register_forward_pre_hook(get_inputs) + main_mod(*example_inputs) + handle.remove() + return sub_inputs + + submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input) + ShapeProp(submod).propagate(*submod_inputs) + + total_input_bytes = 0 + total_output_bytes = 0 + + reports += "Checking inputs...\n" + for n in submod.graph.nodes: + if n.op == "placeholder": + if not is_node_output_tensor(n): + reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n" + else: + total_input_bytes += get_size_of_node(submod, n)[0] + if n.op == "output": + output_node = n + + reports += "Checking outputs...\n" + + def get_bytes(node: torch.fx.Node): + nonlocal total_output_bytes + nonlocal reports + if not is_node_output_tensor(node): + reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n" + else: + total_output_bytes += get_size_of_node(submod, node)[0] + + map_arg(output_node.args, get_bytes) # type: ignore[possibly-undefined] + qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes) + reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes}," + reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n" + + if qps < max_qps: + max_qps = qps + bottleneck_module = node.target + + try: + lowered_submod = self._lower_model_to_backend(submod, submod_inputs) + except RuntimeError: + reports += "Run into an error during lowering!\n" + reports += self._find_culprit(submod, submod_inputs) + continue + + try: + lowered_submod(*submod_inputs) + except RuntimeError: + reports += "Run into an error during inference!\n" + reports += self._find_culprit(submod, submod_inputs) + else: + reports += "Lowering and running succeed!\n" + + reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps}," + reports += f" bottleneck is submodule {bottleneck_module}." + print(reports) + + # return the reports for testing purposes + return reports + + # =============================================================== + # Helpers for extend_acc_subgraph() method + # =============================================================== + + def find_reverse_deps( + self, tag_id: Optional[int] = None + ) -> dict[torch.fx.Node, NodeSet]: + """ + Builds reversed topological node dependencies, if tag_id is specified, + we ignore nodes that are in later subgraph i.e. nodes have greater tag_id. + """ + result: dict[torch.fx.Node, NodeSet] = defaultdict(set) + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + + if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id): + result[node].add(user) + + return result + + def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]): + processed_node = set() + + for node, fusion in self.fusions.items(): + if node in processed_node: + continue + + new_dep = set() + + # Create a new dependency set which include all the + # dependencies of the nodes in the fusion group + for n in fusion: + new_dep.update(deps[n]) + + # Exclude nodes in the fusion + new_dep.difference_update(fusion) + + # Update dependency + for n in fusion: + deps[n] = new_dep + + for arg in n.all_input_nodes: + if arg not in fusion: + deps[arg].update(fusion) + + processed_node.add(n) + + def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet: + """ + Finds parent nodes of the `tag` subgraph. + + Traverse the inputs of nodes in the subgraph, if input doesn't belong to the subgraph + and is not a placeholder, we consider it as the parent node of the subgraph. + """ + parent_nodes = set() + + for node in self.module.graph.nodes: + if node.op in CALLABLE_NODE_OPS and node.tag == tag: + for arg in node.all_input_nodes: + if arg.op in CALLABLE_NODE_OPS and arg.tag != tag: + parent_nodes.add(arg) + + return parent_nodes + + def extend_acc_subgraph(self, tag: str): + """ + Extend the acc subgraph with `tag` going the reversed topological direction. + """ + # Dict that maps node to its users and ignore users that + # are in the subgraph that has greater tag + deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1])) + self.update_reverse_deps_for_fusions(deps) + + # Parent nodes of the subgraph + parent_nodes = self.find_parent_nodes_of_subgraph(tag) + + visited_nodes: NodeSet = set() + + while parent_nodes: + node = None + + # Find a acc node that depends on visited nodes only + for n in parent_nodes: + if deps[n] <= visited_nodes and n in self.acc_nodes: + node = n + break + + if node is None: + break + + # Put the node into `tag` subgraph + node.tag = tag # type: ignore[attr-defined] + parent_nodes.remove(node) + visited_nodes.add(node) + + # If node is in a fusion group, add all fusion buddies to parent nodes + if node in self.fusions: + for fusion_node in self.fusions[node]: + if fusion_node not in visited_nodes: + parent_nodes.add(fusion_node) + + # Add inputs of the node to parent nodes + for arg in node.all_input_nodes: + if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes: + parent_nodes.add(arg) + + # =============================================================== + # Helpers for split() method + # =============================================================== + + def starter_nodes(self) -> tuple[NodeSet, NodeSet]: + """ + Finds nodes that consume module inputs or get_attr nodes. + """ + starter_cpu_nodes: NodeSet = set() + starter_acc_nodes: NodeSet = set() + for node in self.module.graph.nodes: + if node.op not in {"placeholder", "get_attr"}: + continue + for user in node.users: + if user in self.acc_nodes: + starter_acc_nodes.add(user) + else: + starter_cpu_nodes.add(user) + return starter_cpu_nodes, starter_acc_nodes + + def put_nodes_into_subgraphs(self) -> list[Subgraph]: + # We start graph traversal from leaf nodes + current_cpu_nodes, current_acc_nodes = self.starter_nodes() + visited_nodes: NodeSet = set() + + # Determine which subgraph to start from based on which subgraph has + # 0-dep node + acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes) + + current_subgraph_nodes: NodeList = [] + + # Result accumulator + subgraphs: list[Subgraph] = [] + while current_cpu_nodes or current_acc_nodes: + # Find the first node that should belong to the current subgraph and has all dependencies resolved + current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes + node = next( + (n for n in current_nodes if self.deps[n] <= visited_nodes), + None, + ) + + # If nothing was found, then it's time to flip the mode and start a new subgraph + if node is None: + if not current_subgraph_nodes: + raise FxNetSplitterInternalError("Subgraph can't be empty") + + subgraphs.append( + Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) + ) + acc_subgraph = not acc_subgraph + current_subgraph_nodes = [] + continue + + current_nodes.remove(node) + visited_nodes.add(node) + current_subgraph_nodes.append(node) + + # Add fusion buddies + if node in self.fusions: + if node in self.acc_nodes: + current_acc_nodes.update(self.fusions[node] - visited_nodes) + else: + current_cpu_nodes.update(self.fusions[node] - visited_nodes) + + # Put depending nodes into the queue + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + + # Add downstream nodes + if user in self.acc_nodes: + current_acc_nodes.add(user) + else: + current_cpu_nodes.add(user) + + # Check if the last subgraph was not created + if current_subgraph_nodes: + subgraphs.append( + Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) + ) + + if not subgraphs: + raise FxNetSplitterInternalError("Couldn't create subgraphs") + + return subgraphs + + def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph]: + """ + This pass finds ACC submodules with less than specified size and merges + them with adjacent CPU submodules. + """ + result: list[Subgraph] = [] + for subgraph in subgraphs: + if subgraph.is_acc: + if len(subgraph.nodes) >= self.settings.min_acc_module_size: + result.append(subgraph) + else: + print( + "Eliminating acc subgraph because it's smaller than the threshold: " + f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}" + ) + if result: + result[-1].nodes.extend(subgraph.nodes) + else: + subgraph.is_acc = False + result.append(subgraph) + else: + if result and not result[-1].is_acc: + result[-1].nodes.extend(subgraph.nodes) + else: + result.append(subgraph) + return result + + def tag(self, subgraphs: list[Subgraph]): + self.tags = [] + for subgraph in subgraphs: + tag = ( + f"_run_on_acc_{len(self.tags)}" + if subgraph.is_acc + else f"{self.non_acc_submodule_name}{len(self.tags)}" + ) + self.tags.append(tag) + for node in subgraph.nodes: + if hasattr(node, "tag"): + raise FxNetSplitterInternalError(f"Node {node} was already tagged") + + node.tag = tag # type: ignore[attr-defined] + self._node_submodule_map[node.name] = tag + + def split(self, remove_tag: bool = False) -> torch.fx.GraphModule: + split_module = split_by_tags( + self.module, self.tags, return_tuple=self._return_tuple + ) + if remove_tag: + for node in self.module.graph.nodes: + if hasattr(node, "tag"): + del node.tag + return split_module # type: ignore[return-value] + + def __call__(self) -> torch.fx.GraphModule: + subgraphs = self.put_nodes_into_subgraphs() + subgraphs = self.remove_small_acc_subgraphs(subgraphs) + acc_subgraphs_count = len([s for s in subgraphs if s.is_acc]) + non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count + print( + f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs" + ) + self.tag(subgraphs) + return self.split() + + def generate_split_results(self) -> SplitResult: + split_module = self() + submodule_names = [] + for name, _mod in split_module.named_children(): + submodule_names.append(name) + if ( + self.settings.max_acc_splits > 0 + and len(submodule_names) > self.settings.max_acc_splits + ): + raise ValueError( + "Cannot fulfill max_acc_splits limit. " + "This may cause split fragmentation and " + "result in performance issues." + ) + + submodule_inputs = generate_inputs_for_submodules( + split_module, self.sample_input, submodule_names + ) + return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/__init__.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/test_pass_manager.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/test_pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..157dc4017eda576f10793ef46b78cd97b0f5074b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/tests/test_pass_manager.py @@ -0,0 +1,56 @@ +import unittest + +from ..pass_manager import ( + inplace_wrapper, + PassManager, + these_before_those_pass_constraint, + this_before_that_pass_constraint, +) + + +class TestPassManager(unittest.TestCase): + def test_pass_manager_builder(self) -> None: + passes = [lambda x: 2 * x for _ in range(10)] + pm = PassManager(passes) + pm.validate() + + def test_this_before_that_pass_constraint(self) -> None: + passes = [lambda x: 2 * x for _ in range(10)] + pm = PassManager(passes) + + # add unfulfillable constraint + pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0])) + + self.assertRaises(RuntimeError, pm.validate) + + def test_these_before_those_pass_constraint(self) -> None: + passes = [lambda x: 2 * x for _ in range(10)] + constraint = these_before_those_pass_constraint(passes[-1], passes[0]) + pm = PassManager([inplace_wrapper(p) for p in passes]) + + # add unfulfillable constraint + pm.add_constraint(constraint) + + self.assertRaises(RuntimeError, pm.validate) + + def test_two_pass_managers(self) -> None: + """Make sure we can construct the PassManager twice and not share any + state between them""" + + passes = [lambda x: 2 * x for _ in range(3)] + constraint = these_before_those_pass_constraint(passes[0], passes[1]) + pm1 = PassManager() + for p in passes: + pm1.add_pass(p) + pm1.add_constraint(constraint) + output1 = pm1(1) + self.assertEqual(output1, 2**3) + + passes = [lambda x: 3 * x for _ in range(3)] + constraint = these_before_those_pass_constraint(passes[0], passes[1]) + pm2 = PassManager() + for p in passes: + pm2.add_pass(p) + pm2.add_constraint(constraint) + output2 = pm2(1) + self.assertEqual(output2, 3**3) diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/tools_common.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/tools_common.py new file mode 100644 index 0000000000000000000000000000000000000000..212b094e86e3536c2d135be4e93fb6368871ce7a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/tools_common.py @@ -0,0 +1,319 @@ +# mypy: allow-untyped-defs +import collections +import operator +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.node import _get_qualified_name + + +__all__ = [ + "get_acc_ops_name", + "get_node_target", + "is_node_output_tensor", + "FxNetAccFusionsFinder", + "legalize_graph", +] + +Tensors = Union[tuple[torch.Tensor], list[torch.Tensor]] +TensorOrTensors = Union[torch.Tensor, Tensors] +NodeList = list[torch.fx.Node] +NodeSet = set[torch.fx.Node] +Names = list[str] +CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"} + + +@compatibility(is_backward_compatible=False) +def get_acc_ops_name(k): + if isinstance(k, str): + return k + elif k.__module__ and "acc_ops" in k.__module__: + return f"acc_ops.{k.__name__}" + else: + module = k.__module__.replace( + "torch._ops", "torch.ops" + ) # WAR for bug in how torch.ops assigns module + return f"{module if module else ''}.{k.__name__}" + + +@compatibility(is_backward_compatible=False) +def get_node_target( + submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node +) -> str: + """ + Given a `node` returns its target typename. + + For "call_method" node, return node.target which is the name of that method being called. + This could potential lead to conflict but should be okay because normally it's on a tensor. + + For "call_function" node, return typename of node.target. + + For "call_module" node, return typename of the module that node.target point to. + + If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by + "torch". e.g. _VariableFunctionsClass.relu would become torch.relu. + """ + + assert node.op in CALLABLE_NODE_OPS, ( + "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}" + ) + + if node.op == "call_module": + assert isinstance(node.target, str) + submod = submodules[node.target] + submod_type = getattr(submod, "_base_class_origin", type(submod)) + return get_acc_ops_name(submod_type) + elif node.op == "call_function": + target: Any = node.target + return ( + f"acc_ops.{target.__name__}" + if target.__module__ is not None and "acc_ops" in target.__module__ + else _get_qualified_name(target) + ) + else: + assert isinstance(node.target, str) + return node.target + + +@compatibility(is_backward_compatible=False) +def is_node_output_tensor(node: torch.fx.Node) -> bool: + """Checks if the node output produces a Tensor or not. + + NOTE: This requires to run `ShapeProp` on the containing fx graph before + calling this function. This is because it works by checking the `type` + metadata on the node. This metadata is produced by the `ShapeProp`. + """ + type_ = node.meta.get("type", None) + return type_ is not None and issubclass(type_, torch.Tensor) + + +@compatibility(is_backward_compatible=False) +class FxNetAccFusionsFinder: + """ + Finds groups of connected ACC nodes that pass non-tensor data between each other. + Such groups are called fusion groups. + """ + + def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet): + self.module = module + self.nodes = list(module.graph.nodes) + self.acc_nodes = acc_nodes + + @dataclass + class FusionGroup: + # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model. + top_node_idx: int + + # Nodes in this fusion group. + nodes: NodeSet + + # Inputs to this fusion group. + inputs: NodeSet + + # Nodes that in the fusion group that haven't been processed yet. + nodes_need_process: NodeSet + + def add_node(self, node): + """ + Add a node to fusion group. + """ + if node in self.nodes: + return + + self.nodes_need_process.add(node) + self.nodes.add(node) + self.inputs.discard(node) + self.inputs.update( + { + n + for n in node.all_input_nodes + if n.op in CALLABLE_NODE_OPS and n not in self.nodes + } + ) + + def recursive_add_node( + self, + fusion_group: "FxNetAccFusionsFinder.FusionGroup", + inputs: Union[NodeSet, NodeList], + visited: Optional[NodeSet] = None, + ): + """ + Start from inputs and going reverse topological order. If any upstream node + is in the fusion group, add all the nodes in this path to fusion group. + """ + for arg in inputs: + # skip the node if already seen + if visited is not None: + if arg in visited: + continue + visited.add(arg) + + # Skip placeholder and get_attr because they won't be in the fusion group. + if arg.op not in CALLABLE_NODE_OPS: + continue + + # If the node has smaller idx, it's already an upstream node of the fusion + # group. We don't need to check it anymore. + if self.nodes.index(arg) < fusion_group.top_node_idx: + continue + + # If the node is in the fusion group, return True. + if arg in fusion_group.nodes: + return True + + # Check the upstream nodes of the node, if any of them is in the fusion group + # we'll add this node to fusion group and return True. + if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited): + fusion_group.add_node(arg) + return True + + return False + + def __call__(self) -> dict[torch.fx.Node, NodeSet]: + result: dict[torch.fx.Node, NodeSet] = {} + acc_nodes = list(self.acc_nodes) + + for node in acc_nodes: + if node in result: + continue + if node.op not in CALLABLE_NODE_OPS: + continue + if "tensor_meta" in node.meta: + continue + if node not in self.acc_nodes: + continue + + fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup( + top_node_idx=self.nodes.index(node), + nodes={node}, + inputs=set(node.all_input_nodes), + nodes_need_process={node}, + ) + while fusion_group.nodes_need_process: + node = fusion_group.nodes_need_process.pop() + self.recursive_add_node( + fusion_group, + fusion_group.inputs, + visited=set(), + ) + + # Optionally add downstream nodes + if "tensor_meta" not in node.meta: + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + if user in fusion_group.nodes: + continue + + fusion_group.add_node(user) + self.recursive_add_node( + fusion_group, + fusion_group.inputs, + visited=set(), + ) + + # Add some upstream nodes + for arg in node.all_input_nodes: + if arg.op not in CALLABLE_NODE_OPS: + continue + if "tensor_meta" in arg.meta: + continue + if arg in fusion_group.nodes: + continue + + fusion_group.add_node(arg) + fusion_group.top_node_idx = min( + fusion_group.top_node_idx, self.nodes.index(arg) + ) + self.recursive_add_node( + fusion_group, + fusion_group.inputs, + visited=set(), + ) + + if not (set(fusion_group.nodes) <= self.acc_nodes): + self.acc_nodes -= fusion_group.nodes + else: + for n in fusion_group.nodes: + result[n] = fusion_group.nodes + + return result + + +@compatibility(is_backward_compatible=False) +def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Replace the graph of the given GraphModule with one that contains the same nodes as the + original, but in topologically sorted order. + + This is used by the merge_matmul transformation below, which disturbs the topologically sorted + order of its input GraphModule, so that this order is restored before further transformation. + + Arguments: + gm: The graph module to topologically sort. It is modified in-place. + + Returns: + The graph module in-place sorted + """ + + # These operators are used for making runtime assertions before any + # data-dependent operators occur. We want to prioritize sorting these to + # ensure that these assertions appear before any data-dependent operations + # in the graph. + PRIORITIZED_OPS = [ + operator.add, + operator.mul, + operator.sub, + operator.floordiv, + operator.truediv, + operator.mod, + operator.le, + operator.lt, + operator.ge, + operator.gt, + operator.eq, + operator.ne, + torch.ops.aten.sym_constrain_range.default, + torch.ops.aten.sym_constrain_range_for_size.default, + torch.ops.aten._assert_async.msg, + torch.ops.aten.scalar_tensor.default, + torch.ops.aten._assert_scalar.default, + ] + + indeg = dict.fromkeys(gm.graph.nodes, 0) + new_graph = torch.fx.Graph() + # Track how many unfulfilled dependencies each node has + for node in gm.graph.nodes: + for user in node.users: + indeg[user] += 1 + queue: collections.deque = collections.deque() + # Add all nodes with no dependencies to the queue + for node in gm.graph.nodes: + if indeg[node] == 0: + queue.append(node) + env: dict[torch.fx.Node, torch.fx.Node] = {} + # Pop nodes from the queue, and add nodes that have had all their + # dependencies fulfilled + while len(queue) > 0: + cur = queue.popleft() + env[cur] = new_graph.node_copy(cur, lambda x: env[x]) + for user in cur.users: + indeg[user] -= 1 + if indeg[user] == 0: + if user.op == "call_function" and user.target in PRIORITIZED_OPS: + queue.appendleft(user) + else: + queue.append(user) + # If the new graph's size is not as large as the old one, then there must be + # a cycle (i.e. some node's dependencies were not satisfied.) + if len(new_graph.nodes) < len(gm.graph.nodes): + raise RuntimeError( + f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}" + ) + new_graph._codegen = gm.graph._codegen + gm.graph = new_graph + return gm diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__init__.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ee5e7e66868a0776609ff7ffff458f6a91ccf98a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__init__.py @@ -0,0 +1 @@ +from .common import compare_graphs, HolderModule, lift_subgraph_as_module diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..045b7fe0e2ab5e2ecaab0ba4d1510f4c91001555 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68bf5e220fe519ca98e55c404b99579195e5f430 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14c6ac292214c325499738f4e8390c8df72b66ba Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e778dcf5a8fb57127893cad3dbc862e613b5d328 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/common.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..17362c9eec1254305527714234846943460c45fb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/common.py @@ -0,0 +1,94 @@ +# mypy: allow-untyped-defs + +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.passes.utils.matcher_utils import SubgraphMatcher +from torch.nn import Module + + +__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"] + + +@compatibility(is_backward_compatible=False) +class HolderModule(Module): + """ + HolderModule is used to copy all the attributes from original module to submodules + that uses the attributes + """ + + def __init__(self, d): + super().__init__() + for k, v in d.items(): + self.add_module(k, v) + + +@compatibility(is_backward_compatible=False) +def lift_subgraph_as_module( + gm: GraphModule, + subgraph: Graph, + comp_name: str = "", + class_name: str = "GraphModule", +) -> tuple[GraphModule, dict[str, str]]: + """ + Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module. + + Args: + gm (GraphModule): parent graph module + + subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph + + comp_name (str): name for the new component + + class_name (str): name for the submodule + + """ + + # Loop through all module calls (call_module) and param fetches (get_attr) + # in this component, creating HolderModules as necessary to match the path. + # e.g. if in the original module there's a get_attr node fetches "conv.weight". + # We create a HolderModule as root -> add a HolderModule named "conv" -> + # make "weight" a attribute of "conv" HolderModule and point to conv.weight in + # the original module. + submodule = HolderModule({}) + orig_to_split_fqn_mapping: dict[str, str] = {} + for n in subgraph.nodes: + if n.op not in ("call_module", "get_attr"): + continue + + target = n.target + assert isinstance(target, str) + target_name_parts = target.split(".") + curr = submodule + orig_gm = gm + + for name in target_name_parts[:-1]: + if not hasattr(curr, name): + curr.add_module(name, HolderModule({})) + + curr = getattr(curr, name) + orig_gm = getattr(orig_gm, name) + + leaf_node_name = target_name_parts[-1] + leaf_node = getattr(orig_gm, leaf_node_name) + + orig_to_split_fqn_mapping[target] = f"{comp_name}.{target}" + # Relies on custom __setattr__ magic. + setattr(curr, leaf_node_name, leaf_node) + + return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping + + +@compatibility(is_backward_compatible=False) +def compare_graphs(left: Graph, right: Graph) -> bool: + """ + Return True if two graphs are identical, i.e they + - have the same number of outputs in the same order + - have the same number of inputs in the same order + - have the same set of nodes, and identical connectivity + """ + + matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True) + matches = matcher.match(right) + + return len(matches) > 0 diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/fuser_utils.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/fuser_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f2ad8fad1a621c7d81d12b9665b800123f1751d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/fuser_utils.py @@ -0,0 +1,275 @@ +import copy +from queue import SimpleQueue +from typing import Optional as _Optional + +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node +from torch.fx.passes.tools_common import legalize_graph, NodeList, NodeSet +from torch.fx.passes.utils import lift_subgraph_as_module # type: ignore[attr-defined] + + +@compatibility(is_backward_compatible=False) +def topo_sort(nodes: NodeList) -> NodeList: + # sort nodes according to the topological order + indegree_map = dict.fromkeys(nodes, 0) + candidates: SimpleQueue[Node] = SimpleQueue() + + for node in nodes: + for n in node.all_input_nodes: + if n in indegree_map: + indegree_map[node] += 1 + if indegree_map[node] == 0: + candidates.put(node) + + sorted_nodes: NodeList = [] + while not candidates.empty(): + node = candidates.get() + sorted_nodes.append(node) + + for n in node.users: + if n in indegree_map: + indegree_map[n] -= 1 + if indegree_map[n] == 0: + candidates.put(n) + + assert len(nodes) == len(sorted_nodes), ( + "topological sorted nodes doesn't have same length as input nodes" + ) + + return sorted_nodes + + +@compatibility(is_backward_compatible=False) +def validate_partition(partition: NodeList) -> bool: + # verify the partition does't form a dependency cycle in the original graph + # returns True for valid partition, False for invalid + + partition_set = set(partition) + + outputs: NodeList = [] + for node in partition_set: + for user_node in node.users: + if user_node not in partition_set: + # external user node, need to expose as an output + outputs.append(user_node) + + # Perform BFS on the partition outputs. + # If it reaches a node within the partition, then it found a cycle. + # This function takes the ownership of `root_nodes` and may modify it. + def bfs_find_cycle(root_nodes: NodeList) -> bool: + # Set used to exclude nodes that have already been visited. + # If a node has been visited, that node and all its children have + # been checked for cycles. + visited: NodeSet = set() + + # Start with `root_nodes` and traverse through (toward child nodes) + # their connected sub-graph. Nodes in `visited` won't be added + # to `queue` again. + queue: NodeList = root_nodes + while queue: + current = queue.pop() + visited.add(current) + if current in partition_set: + # Started from partition's `output` nodes, and reached + # another node in partition. Cycle! + return True + for user_node in current.users: + if user_node in visited: + continue + queue.append(user_node) + # `root_nodes` don't cause cycle. + return False + + # Use all output nodes as roots to traverse + # the graph to check cycles. + if bfs_find_cycle(outputs): + return False + + return True + + +@compatibility(is_backward_compatible=False) +def fuse_as_graphmodule( + gm: GraphModule, + nodes: NodeList, + module_name: str, + partition_lookup_table: _Optional[dict[Node, None]] = None, + *, + always_return_tuple: bool = False, +) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]: + """ + Fuse nodes in graph_module into a GraphModule. + + Args: + gm (GraphModule): target graph_module + + nodes (List[Node]): list of nodes in `gm` to fuse, where the node must be topologically sorted + + module_name: class name for the fused GraphModule + + partition_lookup_table (Optional[Dict[Node, None]]): optional dict of nodes to speed up lookup + + always_return_tuple (bool): whether to always return a tuple, even if there is only one output + + Returns: + fused_gm (GraphModule): fused graph module, where its node is a copy of `nodes` in `gm` + + original_inputs (Tuple[Node, ...]): input nodes to `nodes` in original `gm` + + original_outputs (Tuple[Node, ...]): consumer nodes of `nodes` in original `gm` + + """ + + # assumption: nodes are already sorted in topo order + + for node in nodes: + assert node.graph.owning_module is gm, ( + f"{node} doesn't belong to passed in graph module {gm._get_name()}" + ) + assert not node._erased, f"{node} has been removed from owning graph" + assert node in gm.graph._find_nodes_lookup_table, ( + f"{node} is not found in graph module {gm._get_name()}" + ) + + # validates partition doesn't introduce dependency circles in the graph + assert validate_partition(nodes), "Invalid partition, found dependency cycles" + + # if no dict of partition nodes is provided, reconstruct it by nodes list to reduce lookup time + if partition_lookup_table is None: + partition_lookup_table = dict.fromkeys(nodes) + + subgraph = Graph() + + node_to_placeholder: dict[ + Node, Node + ] = {} # mapping of nodes from old graph to placeholder in new graph + node_map: dict[Node, Node] = {} # mapping of nodes from old graph to new graph + + # handles inputs through graph.node_copy's arg_transform functions + def remap_inputs(x: Node) -> Node: + if x.op == "get_attr": + # TODO: do we really need copy the get_attr node into the graph? + # do something here + pass + + if x in partition_lookup_table: + # x is inside subgraph, return the copied node + # the node should have been copied aleady, as we are copying graph in the topological order + return node_map[x] + + if x not in node_to_placeholder: + # x is not in subgraph, create a new placeholder for subgraph + placeholder_node = subgraph.placeholder(x.name, type_expr=x.type) + # copy all meta fields, even if some fields might be irrelvant for the placeholder node + placeholder_node.meta = copy.copy(x.meta) + node_to_placeholder[x] = placeholder_node + + return node_to_placeholder[x] + + # copy nodes in topological order + for node in nodes: + new_node = subgraph.node_copy(node, remap_inputs) + node_map[node] = new_node + + # handles outputs + output_mapping: dict[Node, Node] = {} # mapping from old output to new outputs + + for node in nodes: + for user_node in node.users: + if user_node not in partition_lookup_table: + # external user node, need to expose as an output + output_mapping[node] = node_map[node] + + # outs contain nodes in the new subgraph + outs = tuple(output_mapping.values()) + + if always_return_tuple: + # always return a tuple, even if there is only one output + subgraph.output(outs) + else: + # If there's a single output then return it directly, otherwise return a tuple. + subgraph.output(outs[0] if len(outs) == 1 else outs) + + # lint to ensure correctness + subgraph.lint() # type: ignore[no-untyped-call] + fused_gm: GraphModule + fused_gm, _ = lift_subgraph_as_module( + gm, subgraph, comp_name="", class_name=module_name + ) + + # sub_gm's input nodes in the original module + original_inputs: tuple[Node, ...] = tuple(node_to_placeholder.keys()) + + # sub_gm's outputs node in the original module + original_outputs: tuple[Node, ...] = tuple(output_mapping.keys()) + + return fused_gm, original_inputs, original_outputs + + +@compatibility(is_backward_compatible=False) +def insert_subgm( + gm: GraphModule, + sub_gm: GraphModule, + orig_inputs: tuple[Node, ...], + orig_outputs: tuple[Node, ...], +) -> GraphModule: + # add sub_gm into gm + submodule_name = sub_gm.__class__.__name__ + gm.add_submodule(submodule_name, sub_gm) + + # Create a call_module node in main graph. + module_node = gm.graph.call_module(submodule_name, args=orig_inputs, kwargs=None) + + output_node = sub_gm.graph.output_node() + if len(orig_outputs) == 1 and not isinstance(output_node.args[0], tuple): + # main_remapping[comp.orig_outputs[0]] = module_node + orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True) + else: + for i, orig_output in enumerate(orig_outputs): + # Use Proxy to record getitem access. + proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index] + orig_output.replace_all_uses_with(proxy_out, propagate_meta=True) + + module_node.meta["val"] = tuple( + orig_output.meta.get("val", None) for orig_output in orig_outputs + ) + return gm + + +@compatibility(is_backward_compatible=False) +def erase_nodes(gm: GraphModule, nodes: NodeList) -> None: + # erase original nodes in inversed topological order + for node in reversed(nodes): + gm.graph.erase_node(node) + + +@compatibility(is_backward_compatible=False) +def fuse_by_partitions( + gm: GraphModule, + partitions: list[dict[Node, None]], + prefix: str = "fused_", + always_return_tuple: bool = False, +) -> GraphModule: + for partition_id, partition in enumerate(partitions): + sorted_nodes = topo_sort(list(partition)) + + submodule_name = prefix + str(partition_id) + sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule( + gm, + sorted_nodes, + submodule_name, + partition, + always_return_tuple=always_return_tuple, + ) + + insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) + + erase_nodes(gm, sorted_nodes) + + # topological sort original gm with newly created sub_gm + legalize_graph(gm) + + return gm diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_utils.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4f63935875d61ed21f5f50746f7ceb51bda65334 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_utils.py @@ -0,0 +1,440 @@ +# mypy: allow-untyped-defs +import copy +import logging +import os +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Union + +import torch +from torch.fx import Graph, Node +from torch.fx._compatibility import compatibility + + +__all__ = ["SubgraphMatcher", "InternalMatch"] + + +# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs +def _init_logger(): + logger = logging.getLogger(__name__) + + level = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper() + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter("%(filename)s > %(message)s") + console.setFormatter(formatter) + console.setLevel(level) + # add the handlers to the logger + logger.addHandler(console) + logger.propagate = False + return logger + + +logger = _init_logger() + + +@compatibility(is_backward_compatible=False) +@dataclass +class InternalMatch: + # Nodes from which the match was found + anchors: list[Node] + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: dict[Node, Node] = field(default_factory=dict) + + # nodes in target graph that are matched placeholder in pattern + placeholder_nodes: list[Node] = field(default_factory=list) + + # nodes in matched subgraph returned by output + returning_nodes: list[Node] = field(default_factory=list) + + # map from a string name to a node in the target graph + # only available if the matcher is `SubgraphMatcherWithNameNodesMap` + name_node_map: dict[str, Node] = field(default_factory=dict) + + def __copy__(self): + return InternalMatch( + anchors=self.anchors, + nodes_map=self.nodes_map.copy(), + placeholder_nodes=self.placeholder_nodes.copy(), + returning_nodes=self.returning_nodes.copy(), + ) + + +@compatibility(is_backward_compatible=False) +class SubgraphMatcher: + def __init__( + self, + pattern: Graph, + match_output: bool = False, + match_placeholder: bool = False, + remove_overlapping_matches: bool = True, + ignore_literals: bool = False, + ) -> None: + """ + Args: + pattern: the targeted matching pattern, represented in fx.Graph. + match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern. + If False, output node is ignored during match. + match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of + the targeted pattern. If False, placeholder nodes will be used a wildcard. + remove_overlapping_matches: If True, in the case of overlapping matches, only the first match + will be returned. + ignore_literals: If True, will not check if literals are equal and + will instead treat them as wildcards. + """ + + self.pattern = pattern + self.match_output = match_output + self.match_placeholder = match_placeholder + self.remove_overlapping_matches = remove_overlapping_matches + self.ignore_literals = ignore_literals + + if len(pattern.nodes) == 0: + raise ValueError( + "SubgraphMatcher cannot be initialized with an empty pattern" + ) + + for node in pattern.nodes: + if node.op != "output": + assert len(node.users) > 0, ( + "SubgraphMatcher cannot be initialized with an pattern with dead code" + ) + + # TODO: assert pattern is a connected graph + + self.pattern_placeholder_nodes = [ + n for n in pattern.nodes if n.op == "placeholder" + ] + output_node = next(iter(reversed(pattern.nodes))) + # nodes returned by outputs + self.pattern_returning_nodes: list[Node] = output_node.all_input_nodes + + self.pattern_anchors: list[Node] = [] + if match_output: + self.pattern_anchors = [output_node] + else: + # If a node has output_node as the ONLY user, then this node is a graph sink, + # and should be matched against as an anchor + self.pattern_anchors = [ + n for n in output_node.all_input_nodes if len(n.users) == 1 + ] + + def _match_attributes(self, pn: Node, gn: Node) -> bool: + # Attributes matching is complicated. Right now we only support matching constant tensor + assert isinstance(pn.target, str), f"pn.target {pn.target} must be a string." + assert isinstance(gn.target, str), f"gn.target {gn.target} must be a string." + + pn_value = torch.fx.graph_module._get_attr(pn.graph.owning_module, pn.target) + gn_value = torch.fx.graph_module._get_attr(gn.graph.owning_module, gn.target) + + if type(pn_value) != type(gn_value): + return False + + # Don't require exact match on tensor values. + if isinstance(pn_value, torch.Tensor): + return isinstance(gn_value, torch.Tensor) + else: + raise RuntimeError(f"Unsupported type {pn_value} when matching attributes") + return False + + def _nodes_are_equal(self, pn: Node, gn: Node) -> bool: + # if exact match for placeholder is not required, then use placeholder as a wildcard + if not self.match_placeholder and pn.op == "placeholder": + return True + + if pn.op == gn.op: + if pn.op == "placeholder" or pn.op == "output": + return True + elif pn.op == "get_attr": + return self._match_attributes(pn, gn) + return pn.target == gn.target + return False + + def _is_contained(self, nodes_map: dict[Node, Node]) -> bool: + # `lookup` represents all the nodes in `original_graph` + # that are part of `pattern` + + # Placeholders can be used by other nodes in the graphs + lookup: dict[Node, Node] = { + gn: pn for pn, gn in nodes_map.items() if pn.op != "placeholder" + } + + for gn, pn in lookup.items(): + # nodes returned by output are allowed to be used in other areas of the graph + if pn in self.pattern_returning_nodes: + continue + + for user in gn.users: + # If this node has users that were not in `lookup`, then it must leak out of the + # pattern subgraph + if user not in lookup: + return False + return True + + def _remove_overlapping_matches( + self, matches: list[InternalMatch] + ) -> list[InternalMatch]: + non_overlapping_matches: list[InternalMatch] = [] + nodes_matched: set[Node] = set() + + for match in matches: + found_overlap = False + for pn, gn in match.nodes_map.items(): + if pn.op not in {"placeholder", "output"} and gn in nodes_matched: + found_overlap = True + break + + if not found_overlap: + non_overlapping_matches.append(match) + for pn, gn in match.nodes_map.items(): + if pn.op not in {"placeholder", "output"}: + nodes_matched.add(gn) + return non_overlapping_matches + + def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool: + assert not (isinstance(pn, Node) and isinstance(gn, Node)), ( + "pn and gn cannot both be Node" + ) + + if isinstance(pn, Node) and not isinstance(gn, Node): + if pn.op == "placeholder": + # Check if we've already matched these nodes in the current + # traversal + if pn in match.nodes_map: + return match.nodes_map[pn] == gn + + match.nodes_map[pn] = gn + return True + else: + return False + elif not isinstance(pn, Node) and isinstance(gn, Node): + return False + else: + return type(gn) == type(pn) and gn == pn + + def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool: + logger.info(" matching %s to %s", pn, gn) + + assert isinstance(pn, Node) and isinstance(gn, Node), str( + f"pn and gn must be Node, pn: {pn}, gn: {gn}" + ) + + # Check if we've already matched these nodes in the current + # traversal + if pn in match.nodes_map: + return match.nodes_map[pn] == gn + + # TODO: use a more efficient way to check if gn is matched before: two-way dict + if gn in match.nodes_map.values(): + return False + + if not self._nodes_are_equal(pn, gn): + return False + + # Optimistically mark `pn` as a match for `gn`, and save a local copy of match + saved_match = copy.copy(match) + match.nodes_map[pn] = gn + + # Placeholder is a wildcard and can be matched with any python object + # (including list/tuple) + if pn.op == "placeholder": + return True + + # Recursively traverse upwards to check if `pn` is a true + # match for `gn` + match_found = True + + def _match_args(args1: Union[list, tuple], args2: Union[list, tuple]) -> bool: + if len(args1) != len(args2): + return False + + for a1, a2 in zip(args1, args2): + if isinstance(a1, Node) and isinstance(a2, Node): + matched = self._match_nodes(a1, a2, match) + elif isinstance(a1, (list, tuple)) and isinstance(a2, (list, tuple)): + matched = _match_args(a1, a2) + else: + matched = ( + self._match_literals(a1, a2, match) or self.ignore_literals + ) + + if not matched: + return False + + return True + + # Flatten all args/kwargs into 1 list of args + pn_args, gn_args = None, None + if ( + ( + len(pn.args) != len(gn.args) + or list(pn.kwargs.keys()) != list(gn.kwargs.keys()) + ) + and pn.op == "call_function" + and isinstance(pn.target, torch._ops.OpOverload) + ): + args_schema = pn.target._schema.arguments + + def get_all_arguments(orig_args, orig_kwargs): + all_args = [] + for i, schema in enumerate(args_schema): + if schema.name in orig_kwargs: + all_args.append(orig_kwargs[schema.name]) + elif not schema.kwarg_only and i < len(orig_args): + all_args.append(orig_args[i]) + else: + all_args.append(schema.default_value) + return all_args + + pn_args = get_all_arguments(pn.args, pn.kwargs) + gn_args = get_all_arguments(gn.args, gn.kwargs) + + elif len(pn.args) == len(gn.args) and list(pn.kwargs.keys()) == list( + gn.kwargs.keys() + ): + pn_args = list(pn.args) + gn_args = list(gn.args) + pn_args.extend(list(pn.kwargs.values())) + gn_args.extend(list(gn.kwargs.values())) + else: + match_found = False + + match_found = ( + match_found + and pn_args is not None + and gn_args is not None + and _match_args(pn_args, gn_args) + ) + + if not match_found: + # revert to saved_match before matching with current node + match = copy.copy(saved_match) + return False + + return True + + def match(self, graph: Graph) -> list[InternalMatch]: + """ + Returns: + The matched subgraphs. + Thre returned subgraph would be fully self-contained, meaning the nodes (except placeholder + and nodes returned by output) can only be consumed by nodes within the matched subgraph. + + Subgraph pattern matcher is implemented with the backtracking style in the following steps: + + 1. We first identify all the anchor nodes in the pattern graph. The anchor nodes + are the "sinks" (nodes with no user other than the output node) of the pattern graph. + One pattern graph could have multiple anchors if it has multiple return values. + + 2. In the target graph, we identify the potential candidate nodes that can be matched + with each anchor. These anchor-candidate pairs are the starting points for + pairwise per-node matching. + + 3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both + pattern and target graphs. For every pattern nodes along traversal path, we compare it + against the target nodes. In case any comparison failed, the match for this anchor-candidate + pair fails. A match is found when DFS completes traversing the graph. See `self._match_nodes` + for more details. + + 4. In the case of multiple anchors, every anchor will need to find a match using step 3. + In addition, the matches found between anchors need to have a common intersection node + in order for the match to be valid. This is implemented with backtracking. See `backtracking` + for more details. + + Notice: graph traversal must be done in the reverser order because a tensor can have multiple + consumers, but can only have a single producer. Only with reverser order, we can we jointly + traverse the pattern and target graph in a deterministic path. + + Warning: In theory, this backtracking algorithm have an **exponential** time complexity. However, + in practice, it's unlikely to blow up. + + """ + from torch.fx.passes.utils.fuser_utils import validate_partition + + # find candidate nodes to match with pattern anchors + match_candidates: dict[Node, list[Node]] = defaultdict(list) + for pattern_anchor in self.pattern_anchors: + for node in graph.nodes: + if self._nodes_are_equal(pattern_anchor, node): + match_candidates[pattern_anchor].append(node) + match_candidates_list = list(match_candidates.items()) + + logger.info("Initial match_candidates_list: %s\n", match_candidates_list) + + matches: list[InternalMatch] = [] + + def backtracking(anchor_index, match): + if anchor_index == len(match_candidates_list): + match.placeholder_nodes = [ + match.nodes_map[pn] for pn in self.pattern_placeholder_nodes + ] + match.returning_nodes = [ + match.nodes_map[pn] for pn in self.pattern_returning_nodes + ] + matches.append(match) + + logger.info("Found a match: %s\n", match) + return + + pattern_anchor, candidate_nodes = match_candidates_list[anchor_index] + saved_match = copy.copy(match) + + for node in candidate_nodes: + logger.info("Trying to match anchor %s to %s", pattern_anchor, node) + + match_found = self._match_nodes(pattern_anchor, node, match) + if match_found: + # match next anchor + backtracking(anchor_index + 1, match) + else: + logger.info( + "Failed to match anchor %s to %s\n", pattern_anchor, node + ) + + # revert to saved_match before matching with current anchor + match = copy.copy(saved_match) + + match = InternalMatch(anchors=self.pattern_anchors) + if match_candidates_list: + backtracking(0, match) + + # filter out the matches where the subgraph is not fully_contained + before = len(matches) + matches = [match for match in matches if self._is_contained(match.nodes_map)] + after = len(matches) + if before != after: + logger.info( + "Filtered out %s matches because they are not fully contained", + before - after, + ) + + # filter out the matches that form a cycle if the subgraph is fused + valid_matches = [] + for match in matches: + matched_compute_nodes = [ + gn + for pn, gn in match.nodes_map.items() + if pn.op not in {"placeholder", "output"} + ] + if validate_partition(matched_compute_nodes): + valid_matches.append(match) + if len(valid_matches) != len(matches): + logger.info( + "Filtered out %s matches because \ + matched subgraph would form a cycle if fused", + len(matches) - len(valid_matches), + ) + + if self.remove_overlapping_matches: + before = len(valid_matches) + matches = self._remove_overlapping_matches(valid_matches) + after = len(matches) + if before != after: + logger.info( + "Filtered out %s matches because matched subgraphs are overlapping", + before - after, + ) + + logger.info("Matches returned: %s", matches) + + return matches diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..091ec7f1f82b234c33307ffe0f8c972911fe6d47 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py @@ -0,0 +1,114 @@ +from torch.fx import Graph, GraphModule, Node +from torch.fx._compatibility import compatibility + +from .matcher_utils import InternalMatch, SubgraphMatcher + + +__all__ = ["SubgraphMatcherWithNameNodeMap"] + + +def _split_to_graph_and_name_node_map( + gm: GraphModule, +) -> tuple[GraphModule, dict[str, Node]]: + from torch.fx.graph import _PyTreeInfo + from torch.utils._pytree import tree_flatten, tree_unflatten + + name_node_map = {} + for n in gm.graph.nodes: + if n.op == "output": + assert gm._out_spec is not None + output = tree_unflatten(n.args[0], gm._out_spec) + assert isinstance(output, tuple), ( + "Expecting the pattern graph to return a tuple" + ) + assert len(output) >= 2, ( + "Expecting the pattern graph to have at least two outputs" + ) + *out, name_node_map = output + flattened, out_spec = tree_flatten(out) + assert isinstance(name_node_map, dict), ( + "Expecting the input graph to have a dict output as the last element" + ) + n.args = (flattened,) + orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined] + gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined] + orig_pytree_info.orig_args, orig_pytree_info.in_spec, out_spec + ) + gm.recompile() + return gm, name_node_map + + +@compatibility(is_backward_compatible=False) +class SubgraphMatcherWithNameNodeMap(SubgraphMatcher): + """Extends SubgraphMatcher to support querying the matched subgraph nodes through node name, + this requires pattern to have specific format (returning and additional dictionary at the output, + that has node name as key, and the node in the pattern graph as value, see Example for more details) + + Difference with SubgraphMatcher is that it takes a `pattern_gm` GraphModule as input during + initialization since we need to modify the graph (which requires `recompile` the GraphModule) + + Example:: + def pattern(x, weight): + conv = F.conv2d(x, weight) + relu = F.relu(conv) + return relu, {"conv": conv, "relu": relu} + + + def target_graph(x, weight): + conv = F.conv2d(x, weight) + relu = F.relu(conv) + relu *= 2 + return relu + + + pattern_gm = export_for_training(pattern, example_inputs).module() + target_gm = export_for_training(target_graph, example_inputs).module() + matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) + matches = matcher.match(target_gm) + for match in matches: + match.name_node_map["conv"].meta["annotation"] = ... + + """ + + def __init__( + self, + pattern_gm: GraphModule, + match_output: bool = False, + match_placeholder: bool = False, + remove_overlapping_matches: bool = True, + ignore_literals: bool = False, + ) -> None: + pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm) + self.name_node_map = name_node_map + super().__init__( + pattern_gm.graph, + match_output, + match_placeholder, + remove_overlapping_matches, + ignore_literals, + ) + + def match(self, graph: Graph) -> list[InternalMatch]: + """The returned InternalMatch will have name_node_map populated with a map + from node name (str) to the target node, e.g. + {"conv": target_conv_ndoe, "relu": target_relu_node} + + this requires the pattern graph returns an additional + output of node name to node, e.g. instead of: + ``` + def pattern(...): + ... + return relu + ``` + we should do: + ``` + def pattern(...): + ... + return relu, {"conv": conv, "relu": relu} + ``` instead + """ + internal_matches = super().match(graph) + for internal_match in internal_matches: + for k, n in self.name_node_map.items(): + internal_match.name_node_map[k] = internal_match.nodes_map[n] + return internal_matches diff --git a/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/source_matcher_utils.py b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/source_matcher_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..97a60b06694c0ba9b03651ecf422574895ba8283 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/fx/passes/utils/source_matcher_utils.py @@ -0,0 +1,162 @@ +import logging +import os +from dataclasses import dataclass, field +from typing import Any, Callable, Optional + +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.node import Node + + +__all__ = ["get_source_partitions", "check_subgraphs_connected", "SourcePartition"] + + +# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs +def _init_logger() -> logging.Logger: + logger = logging.getLogger(__name__) + + level = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper() + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter("%(filename)s > %(message)s") + console.setFormatter(formatter) + console.setLevel(level) + # add the handlers to the logger + logger.addHandler(console) + logger.propagate = False + return logger + + +logger = _init_logger() + + +@compatibility(is_backward_compatible=False) +@dataclass +class SourcePartition: + # Nodes in a particular partition + nodes: list[Node] + + # The source these nodes decomposed from + source: Any + + # Nodes in the graph that are needed as inputs to the partition + # These do not include the params of the partition + input_nodes: list[Node] = field(default_factory=list) + + # Nodes in the partition that are being used by nodes outside of the + # partition + output_nodes: list[Node] = field(default_factory=list) + + # Parameters that are being used + params: list[Node] = field(default_factory=list) + + +@compatibility(is_backward_compatible=False) # type: ignore[misc] +def get_source_partitions( + graph: Graph, + wanted_sources: list[Any], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> dict[Any, list[SourcePartition]]: + """ + Args: + graph: The graph we want to partition + wanted_sources: List of sources of nodes that were decomposed from this + source. This can be a function (ex. torch.nn.functional.linear) or a + leaf module type (ex. torch.nn.Linear). + + Returns: + Dictionary mapping sources that were given to a list of SourcePartitions + that correspond to the list of nodes that were decomposed from the given + source. + """ + modules: dict[type, dict[str, list[Node]]] = {} + + for node in graph.nodes: + # The metadata source_fn should contain a tuple of a unique name for the + # source, and the source function if the node is decomposed from a + # function, or the type of module if the node is decomposed from a leaf + # module + + # TODO: Bypass "torch_fn" when "source_fn_stack" because now "torch_fn" can + # be different from "source_fn_stack", for example for the add_ node + # decomposed from batch norm. We should remove the check on "source_fn_stack" + # after we fix "torch_fn". T199561090 + if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and ( + torch_fn := node.meta.get("torch_fn", None) + ) is not None: + node_fqn, source_fn = torch_fn + source_fn_name = source_fn.split(".")[1] + if source_fn_name in wanted_sources: + diff_modules = modules.setdefault(source_fn_name, {}) + partition = diff_modules.setdefault(node_fqn, []) + partition.append(node) + + if (source_fn_st := node.meta.get("source_fn_stack", None)) is not None: + source_fn = source_fn_st[-1] + if source_fn[1] in wanted_sources: + diff_modules = modules.setdefault(source_fn[1], {}) + partition = diff_modules.setdefault(source_fn[0], []) + partition.append(node) + + def make_partition(nodes: list[Node], module_type: type) -> SourcePartition: + input_nodes = set() + output_nodes = set() + params = set() + for node in nodes: + for arg in node.args: + if isinstance(arg, Node) and arg not in nodes and arg.op != "get_attr": + input_nodes.add(arg) + + if node.op == "get_attr": + params.add(node) + # get_attr nodes won't be output nodes + continue + + for user in node.users.keys(): + if user not in nodes: + output_nodes.add(node) + + return SourcePartition( + nodes, + module_type, + list(input_nodes), + list(output_nodes), + list(params), # type: ignore[arg-type] + ) + + ret: dict[type[Any], list[SourcePartition]] = {} + + if filter_fn: + # for each partition, we apply filter_fn to filter out all partitions that doesn't satisfy the + # filter condition + filtered_modules = {} + for tp, name_to_partition in modules.items(): + filtered_name_to_partition = { + name: partition + for name, partition in name_to_partition.items() + if all(map(filter_fn, partition)) + } + filtered_modules[tp] = filtered_name_to_partition + modules = filtered_modules + + for k, v in modules.items(): + ret[k] = [make_partition(partition, k) for partition in v.values()] + + return ret + + +@compatibility(is_backward_compatible=False) # type: ignore[misc] +def check_subgraphs_connected( + subgraph1: SourcePartition, subgraph2: SourcePartition +) -> bool: + """ + Given two subgraphs A and B (in the form of a list of nodes), checks if + A has nodes connecting to at least one node in B -- aka there exists a node + in B that uses a node in A (not the other way around). + """ + + for node in reversed(subgraph1.nodes): + for user in node.users.keys(): + if user in subgraph2.nodes: + return True + return False diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/all.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/all.h new file mode 100644 index 0000000000000000000000000000000000000000..026f4f9f579e906893636fed0f4e84bcf5440e3f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/all.h @@ -0,0 +1,23 @@ +#pragma once + +#if !defined(_MSC_VER) && __cplusplus < 201703L +#error C++17 or later compatible compiler is required to use PyTorch. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/arg.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/arg.h new file mode 100644 index 0000000000000000000000000000000000000000..9af29d5446aef33ed1cae046888423009494d753 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/arg.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +#define TORCH_ARG(T, name) \ + public: \ + inline auto name(const T& new_##name) -> decltype(*this) { /* NOLINT */ \ + this->name##_ = new_##name; \ + return *this; \ + } \ + inline auto name(T&& new_##name) -> decltype(*this) { /* NOLINT */ \ + this->name##_ = std::move(new_##name); \ + return *this; \ + } \ + inline const T& name() const noexcept { /* NOLINT */ \ + return this->name##_; \ + } \ + inline T& name() noexcept { /* NOLINT */ \ + return this->name##_; \ + } \ + \ + private: \ + T name##_ /* NOLINT */ diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/autograd.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/autograd.h new file mode 100644 index 0000000000000000000000000000000000000000..cf0608fa01bbf5549d81c6edc1b3e4cd82de379b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/autograd.h @@ -0,0 +1,5 @@ +#pragma once + +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/cuda.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..1eecfe9b5ddf9c8538113b09c8de01d1e7099cc7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/cuda.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +#include + +namespace torch::cuda { + +/// Returns the number of CUDA devices available. +c10::DeviceIndex TORCH_API device_count(); + +/// Returns true if at least one CUDA device is available. +bool TORCH_API is_available(); + +/// Returns true if CUDA is available, and CuDNN is available. +bool TORCH_API cudnn_is_available(); + +/// Sets the seed for the current GPU. +void TORCH_API manual_seed(uint64_t seed); + +/// Sets the seed for all available GPUs. +void TORCH_API manual_seed_all(uint64_t seed); + +/// Waits for all kernels in all streams on a CUDA device to complete. +void TORCH_API synchronize(int64_t device_index = -1); + +} // namespace torch::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data.h new file mode 100644 index 0000000000000000000000000000000000000000..78aae1d25c27cc87a9829efa60b69b69c5de5de4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include +#include +#include + +// Some "exports". + +namespace torch::data { +using datasets::BatchDataset; // NOLINT +using datasets::Dataset; // NOLINT +} // namespace torch::data diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h new file mode 100644 index 0000000000000000000000000000000000000000..c60abc79c847efead69f425c221fda26debcd4eb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +#include + +#include + +#include +#include +#include +#include + +namespace torch::data { + +/// Creates a `DataLoader` instance for a stateless `dataset`, a `sampler` and +/// some `options`. +template +std::enable_if_t< + !Dataset::is_stateful, + std::unique_ptr>> +make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) { + return std::make_unique>( + std::move(dataset), std::move(sampler), options); +} + +/// Creates a `DataLoader` instance for a stateless `dataset` and some +/// `options`. A sampler (by default a `RandomSampler`) will be constructed from +/// the size of the dataset. +template +std::enable_if_t< + !Dataset::is_stateful && std::is_constructible_v, + std::unique_ptr>> +make_data_loader( + Dataset dataset, + DataLoaderOptions options = DataLoaderOptions()) { + const std::optional size = dataset.size(); + TORCH_CHECK( + size.has_value(), + "Expected the dataset to be sized in " + "order to construct the Sampler"); + return make_data_loader(std::move(dataset), Sampler(*size), options); +} + +/// Creates a `DataLoader` for a stateful `dataset` and some `options`. +template > +std::unique_ptr> make_data_loader( + Dataset dataset, + DataLoaderOptions options = DataLoaderOptions()) { + return std::make_unique>( + std::move(dataset), options); +} +} // namespace torch::data diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h new file mode 100644 index 0000000000000000000000000000000000000000..35901ff991e938352f0315dc09a7e42890db9582 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h @@ -0,0 +1,254 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace torch::data { +template +class DataLoaderBase { + public: + using BatchType = Batch; + using BatchRequestType = BatchRequest; + + /// Constructs a new DataLoader from a `dataset` to sample from, `options` + /// to configure the DataLoader with, and a `sampler` that specifies the + /// sampling strategy. + DataLoaderBase( + DataLoaderOptions options, + std::unique_ptr main_thread_dataset = nullptr) + : options_(options), + main_thread_dataset_(std::move(main_thread_dataset)), + sequencer_(new_sequencer()) {} + + DataLoaderBase(const DataLoaderBase&) = delete; + DataLoaderBase(DataLoaderBase&&) = delete; + DataLoaderBase& operator=(const DataLoaderBase&) = delete; + DataLoaderBase& operator=(DataLoaderBase&&) = delete; + // NOLINTNEXTLINE(bugprone-exception-escape) + virtual ~DataLoaderBase() { + join(); + } + + /// Returns an iterator into the DataLoader. The lifetime of the iterator is + /// bound to the DataLoader. In C++ standards language, the category of the + /// iterator is `OutputIterator`. See + /// https://en.cppreference.com/w/cpp/named_req/OutputIterator for what this + /// means. In short: you may increment the iterator and dereference it, but + /// cannot go back, or step forward more than one position at a time. When the + /// DataLoader is exhausted, it will compare equal with the special + /// "sentinel" iterator returned by `DataLoader::end()`. Most of the time, you + /// should only use range-for loops to loop over the DataLoader, but + /// standard algorithms like `std::copy(dataloader.begin(), dataloader.end(), + /// output_iterator)` are supported too. + Iterator begin() { + TORCH_CHECK( + shuttle_.in_flight_jobs() == 0, + "Attempted to get a new DataLoader iterator " + "while another iterator is not yet exhausted"); + reset(); + return Iterator(std::make_unique>( + [this] { return this->next(); })); + } + + /// Returns a special "sentinel" iterator that compares equal with a + /// non-sentinel iterator once the DataLoader is exhausted. + Iterator end() { + return Iterator(std::make_unique>()); + } + + /// Joins the DataLoader's worker threads and drains internal queues. + /// This function may only be invoked from the main thread (in which the + /// DataLoader lives). + void join() { + if (joined_) { + return; + } + shuttle_.drain(); + // Send one 'quit' message per worker. Since a worker dies (exits its + // thread) after receiving this message, each `QuitWorker()` message will be + // read by exactly one worker. + for ([[maybe_unused]] const auto w : c10::irange(options_.workers)) { + push_job(QuitWorker()); + } + for (auto& worker : workers_) { + worker.join(); + } + joined_ = true; + } + + /// Returns the options with which the DataLoader was configured. + const FullDataLoaderOptions& options() const noexcept { + return options_; + } + + protected: + /// Simple mix-in to give something a sequence number. + struct Sequenced { + Sequenced() = default; + Sequenced(size_t sqn) : sequence_number(sqn) {} + size_t sequence_number; + }; + + struct QuitWorker {}; + + /// A `Job` is either a `BatchRequest` (new indices to fetch data at) or a + /// `QuitWorker` object, to indicate the worker should shut down. + struct Job : Sequenced { + Job() = default; + Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {} + Job(BatchRequest&& i, size_t sqn) + : Sequenced(sqn), batch_request(std::move(i)) {} + std::optional quit; + std::optional batch_request; + }; + + /// The finished result of a job. + struct Result : Sequenced { + Result() = default; + Result(std::optional&& b, size_t sqn) + : Sequenced(sqn), batch(std::move(b)) {} + Result(std::exception_ptr exception, size_t sqn) + : Sequenced(sqn), exception(std::move(exception)) {} + std::optional batch; + std::exception_ptr exception; + }; + + /// Subclass hook for getting the next batch request. The stateless case will + /// ask the sampler for a new batch request (e.g. a vector of indices), while + /// the stateful one will simply return the batch size. + virtual std::optional get_batch_request() = 0; + + /// Resets the internal state of the DataLoader, optionally pre-fetching + /// new jobs. + virtual void reset() { + shuttle_.drain(); + sequence_number_ = 0; + sequencer_ = new_sequencer(); + prefetch(); + } + + /// Schedules `requested_jobs` many new batches to be fetched. The actual + /// number of jobs scheduled may be less if the DataLoader exhausts. + void prefetch(size_t requested_jobs) { + for ([[maybe_unused]] const auto r : c10::irange(requested_jobs)) { + if (auto batch_request = get_batch_request()) { + this->push_job(std::move(*batch_request)); + } else { + break; + } + } + } + + /// Schedules the maximum number of jobs (based on the `max_jobs` option). + void prefetch() { + prefetch(options_.max_jobs); + } + + /// Returns the next batch of data, or an empty `optional` if the DataLoader + /// is exhausted. This operation will block until a batch is available if one + /// is still expected. + std::optional next() { + if (options_.workers > 0) { + while (std::optional result = this->pop_result()) { + if (result->exception) { + throw WorkerException(result->exception); + } else if (result->batch) { + prefetch(1); + return std::move(result->batch); + } + } + } else if (auto batch_request = get_batch_request()) { + return this->main_thread_dataset_->get_batch(std::move(*batch_request)); + } + return std::nullopt; + } + + /// The function that worker threads run. + void worker_thread(Dataset& dataset) { + while (true) { + auto job = shuttle_.pop_job(); + if (job.quit) { + break; + } + try { + auto batch = dataset.get_batch(std::move(*job.batch_request)); + shuttle_.push_result({std::move(batch), job.sequence_number}); + } catch (...) { + shuttle_.push_result({std::current_exception(), job.sequence_number}); + } + } + } + + /// Convenience method that calls `shuttle_.push_job()` with the next sequence + /// number. + template + void push_job(T value) { + shuttle_.push_job({std::move(value), sequence_number_++}); + } + + /// Convenience method that gets the next result from the sequencer. + std::optional pop_result() { + return sequencer_->next( + [this] { return this->shuttle_.pop_result(this->options_.timeout); }); + } + + /// Convenience method that creates a new sequencer based on the + /// `enforce_ordering` option. + std::unique_ptr> new_sequencer() { + if (options_.enforce_ordering) { + return std::make_unique>( + options_.max_jobs); + } + return std::make_unique>(); + } + + /// The options the DataLoader was configured with. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const FullDataLoaderOptions options_; + + /// The dataset for the main thread, only has a value if the number of + /// worker threads was configured as zero, meaning the main thread has to do + /// all the work (synchronously). NOTE: Really want this to be on the heap + /// when empty, therefore `unique_ptr` and not `optional`. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::unique_ptr main_thread_dataset_; + + /// The sequence number for the *next* batch to be retrieved from the + /// dataset. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + size_t sequence_number_ = 0; + + /// The worker threads, running the `worker_thread()` method. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::vector workers_; + + /// The `DataShuttle` which takes care of the life cycle of a job. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + detail::DataShuttle shuttle_; + + /// The `Sequencer`, which handles optional ordering of batches. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::unique_ptr> sequencer_; + + /// True if the DataLoader has joined its worker threads. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool joined_ = false; +}; +} // namespace torch::data diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h new file mode 100644 index 0000000000000000000000000000000000000000..964a1ffcc7f6ca7bad92bb142aaec54afb437ecb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch::data { + +/// A dataloader for stateful datasets. +/// +/// A dataloader for stateful datatasets differs from one for stateless +/// datasets one in that the dataset is shared among worker threads, and that +/// this dataset is itself responsible for producing batches rather than +/// depending on a sampler. The statefulness here actually refers to the +/// dataset. The StatefulDataLoader simply alters the data loading algorithm to +/// accommodate the stateful, shared nature of the dataset. Note that the +/// dataset must be thread safe if more than one worker thread is used. +/// +/// A stateful dataloader is created by calling `make_data_loader` with a +/// stateful dataset. +template +class StatefulDataLoader : public DataLoaderBase< + Dataset, + typename Dataset::BatchType::value_type, + typename Dataset::BatchRequestType> { + public: + using super = DataLoaderBase< + Dataset, + typename Dataset::BatchType::value_type, + typename Dataset::BatchRequestType>; + using typename super::BatchRequestType; + + /// Constructs the `StatefulDataLoader` from a `dataset` and some `options`. + StatefulDataLoader(Dataset dataset, DataLoaderOptions options) + : super(options, std::make_unique(std::move(dataset))) { + for ([[maybe_unused]] const auto _ : c10::irange(this->options_.workers)) { + // As opposed to the stateless case, here all worker threads access the + // same underlying dataset. + this->workers_.emplace_back( + [this] { this->worker_thread(*this->main_thread_dataset_); }); + } + } + + private: + /// Resets the internal state of the dataloader and the dataset. + void reset() override { + this->main_thread_dataset_->reset(); + // Call the base class method last because it calls `prefetch()` + super::reset(); + } + + /// For stateful datasets, the batch request is always the batch size. The + /// dataset is responsible for determining what goes into the batch next. + std::optional get_batch_request() override { + return this->options_.batch_size; + } +}; +} // namespace torch::data diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateless.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateless.h new file mode 100644 index 0000000000000000000000000000000000000000..07bf330205442450e848a249327560f9c1687ae8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateless.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +namespace torch::data { + +/// A dataloader for stateless datasets. +/// +/// This dataloader follows the traditional PyTorch dataloader design, whereby a +/// (possibly) stateful sampler produces *batch requests* for a stateless +/// dataset, which acts as a simple batch request to batch mapping. The batch +/// request will often be an array of indices, and if the dataset is a simple +/// image dataset, the dataset would produce the images at those indices. +template +class StatelessDataLoader : public DataLoaderBase< + Dataset, + typename Dataset::BatchType, + typename Sampler::BatchRequestType> { + public: + using super = DataLoaderBase< + Dataset, + typename Dataset::BatchType, + typename Sampler::BatchRequestType>; + using typename super::BatchRequestType; + + /// Constructs the `StatelessDataLoader` from a `dataset`, a `sampler` and + /// some `options`. + StatelessDataLoader( + Dataset dataset, + Sampler sampler, + DataLoaderOptions options) + : super(options), sampler_(std::move(sampler)) { + for (const auto w : c10::irange(this->options_.workers)) { + // Here we copy the dataset into the worker thread closure. Each worker + // has its own copy of the dataset. This means the dataset must be + // trivially copiable, or else we don't expect more than one worker to + // be in use. + (void)w; // Suppress unused variable warning + this->workers_.emplace_back( + [this, dataset]() mutable { this->worker_thread(dataset); }); + } + if (this->options_.workers == 0) { + this->main_thread_dataset_ = + std::make_unique(std::move(dataset)); + } + } + + private: + /// Resets the internal state of the dataloader and the sampler. + void reset() override { + sampler_.reset(); + // Call the base class method last because it calls `prefetch()` + super::reset(); + } + + /// Queries the sampler for the next batch request (possibly progressing its + /// internal state). + std::optional get_batch_request() override { + auto indices = sampler_.next(this->options_.batch_size); + if (!indices || + (indices->size() < this->options_.batch_size && + this->options_.drop_last)) { + return std::nullopt; + } + AT_ASSERT(indices->size() > 0); + return indices; + } + + /// The `Sampler` used to produce batch requests. + Sampler sampler_; +}; +} // namespace torch::data diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h new file mode 100644 index 0000000000000000000000000000000000000000..34dd3a00dc47a0640b41840c317022fbd9f9b56c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch::data { + +/// Options to configure a `DataLoader`. +struct DataLoaderOptions { + DataLoaderOptions() = default; + /* implicit */ DataLoaderOptions(size_t batch_size) + : batch_size_(batch_size) {} + + /// The size of each batch to fetch. + TORCH_ARG(size_t, batch_size) = 1; + + /// The number of worker threads to launch. If zero, the main thread will + /// synchronously perform the data loading. + TORCH_ARG(size_t, workers) = 0; + + /// The maximum number of jobs to enqueue for fetching by worker threads. + /// Defaults to two times the number of worker threads. + TORCH_ARG(std::optional, max_jobs); + + /// An optional limit on the time to wait for the next batch. + TORCH_ARG(std::optional, timeout); + + /// Whether to enforce ordering of batches when multiple are loaded + /// asynchronously by worker threads. Set to `false` for better performance if + /// you do not care about determinism. + TORCH_ARG(bool, enforce_ordering) = true; + + /// Whether to omit the last batch if it contains less than `batch_size` + /// examples. + TORCH_ARG(bool, drop_last) = false; +}; + +/// Like `DataLoaderOptions`, but without any unconfigured state. +/// `DataLoaderOptions` has some options that depend on other options +/// (`max_jobs` => `2 * workers`). In the spirit of properly using the C++ type +/// system, `DataLoaderOptions` allows only setting values. To access values, +/// you must create a `FullDataLoaderOptions` from a `DataLoaderOptions` +/// instance, which will do any necessary coalescing. +struct FullDataLoaderOptions { + explicit FullDataLoaderOptions(DataLoaderOptions options) + : batch_size(options.batch_size()), + workers(options.workers()), + max_jobs(options.max_jobs().value_or(2 * workers)), + timeout(options.timeout()), + enforce_ordering(options.enforce_ordering()), + drop_last(options.drop_last()) {} + + size_t batch_size; + size_t workers; + size_t max_jobs; + std::optional timeout; + bool enforce_ordering; + bool drop_last; +}; +} // namespace torch::data diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets.h new file mode 100644 index 0000000000000000000000000000000000000000..df565e97235828e5c89c76f0373bc1cdaee01287 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/base.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/base.h new file mode 100644 index 0000000000000000000000000000000000000000..e5232ab0d7a3c2f9252917dcef442c82c0fccc41 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/base.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace torch::data::datasets { +template +class MapDataset; +template +MapDataset map(D, T); // NOLINT +} // namespace torch::data::datasets + +namespace torch::data::datasets { +namespace detail { +template +struct is_optional : std::false_type {}; +template +struct is_optional> : std::true_type {}; +} // namespace detail + +/// A dataset that can yield data only in batches. +template < + typename Self, + typename Batch = std::vector>, + typename BatchRequest = ArrayRef> +class BatchDataset { + public: + using SelfType = Self; + using BatchType = Batch; + using BatchRequestType = BatchRequest; + constexpr static bool is_stateful = detail::is_optional::value; + + virtual ~BatchDataset() = default; + + /// Returns a batch of data given an index. + virtual Batch get_batch(BatchRequest request) = 0; + + /// Returns the size of the dataset, or an empty std::optional if it is + /// unsized. + virtual std::optional size() const = 0; + + /// Creates a `MapDataset` that applies the given `transform` to this dataset. + template + MapDataset map(TransformType transform) & { + return datasets::map(static_cast(*this), std::move(transform)); + } + + /// Creates a `MapDataset` that applies the given `transform` to this dataset. + template + MapDataset map(TransformType transform) && { + return datasets::map( + std::move(static_cast(*this)), std::move(transform)); + } +}; + +/// A dataset that can yield data in batches, or as individual examples. +/// +/// A `Dataset` is a `BatchDataset`, because it supports random access and +/// therefore batched access is implemented (by default) by calling the random +/// access indexing function for each index in the requested batch of indices. +/// This can be customized. +template > +class Dataset : public BatchDataset> { + public: + using ExampleType = SingleExample; + + /// Returns the example at the given index. + virtual ExampleType get(size_t index) = 0; + + /// Returns a batch of data. + /// The default implementation calls `get()` for every requested index + /// in the batch. + std::vector get_batch(ArrayRef indices) override { + std::vector batch; + batch.reserve(indices.size()); + for (const auto i : indices) { + batch.push_back(get(i)); + } + return batch; + } +}; + +/// A `StreamDataset` represents a dataset that is a potentially infinite +/// stream. It takes as batch index only a number, which is the batch size, and +/// yields that many elements from the stream. +template >> +using StreamDataset = BatchDataset; +} // namespace torch::data::datasets diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/chunk.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/chunk.h new file mode 100644 index 0000000000000000000000000000000000000000..1eba537c44c2814e3ae02fa128097e29d232f7e7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/chunk.h @@ -0,0 +1,527 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace torch::data::datasets { + +/// Interface for chunk reader, which performs data chunking and reading of +/// entire chunks. +/// +/// A chunk could be an entire file, such as an audio data file or an image, +/// or part of a file in the case of a large text-file split based on seek +/// positions. +template < + typename ExampleType_, + typename ChunkType_ = std::vector> +class ChunkDataReader { + public: + virtual ~ChunkDataReader() = default; + + using ChunkType = ChunkType_; + using ExampleType = ExampleType_; + + /// Read an entire chunk. + virtual ChunkType read_chunk(size_t chunk_index) = 0; + + /// Returns the number of chunks available in this reader. + virtual size_t chunk_count() = 0; + + /// This will clear any internal state associate with this reader. + virtual void reset() = 0; +}; + +namespace detail { +/// BatchDataBuffer manages a queue of UnwrappedBatchData. After a new chunk is +/// loaded, BatchDataBuffer splits it into small batches and push them into the +/// queue. When get_batch is called from data loader, it pops cached batches and +/// return. If the cache is empty, it either waits to load more chunks or return +/// null if all chunks are loaded. +template < + typename UnwrappedBatch, + typename ExampleSampler = samplers::RandomSampler> +class BatchDataBuffer { + public: + using UnwrappedBatchType = UnwrappedBatch; + using BatchType = std::optional; + using BatchRequestType = typename ExampleSampler::BatchRequestType; + + BatchDataBuffer( + size_t batch_size, + ExampleSampler& example_sampler, + size_t queue_capacity) + : batch_size_(batch_size), + example_sampler_(example_sampler), + queue_capacity_(queue_capacity) {} + + /// Return batch data from the queue. Called from the ChunkDataset main + /// thread. + BatchType get_batch() { + std::unique_lock lock(queue_mutex_); + cv_read_.wait(lock, [this] { + // wait till there is available data in the queue or if all chunks are + // loaded (i.e. the dataset is exhausted for this epoch) + return ( + this->total_example_count_in_queue_ >= batch_size_ || this->stop_); + }); + if (batch_queue_.empty()) { + AT_ASSERT(stop_); + // All batches have been retrieved. Return an empty batch. + return std::nullopt; + } + + UnwrappedBatchData batch = std::move(batch_queue_.front()); + batch_queue_.pop(); + if (batch.exception) { + throw WorkerException(batch.exception); + } + + total_example_count_in_queue_ -= batch.batch_data.size(); + lock.unlock(); + cv_write_.notify_all(); + + return batch.batch_data; + } + + /// Push preloaded chunks to batch queue. Called from the ChunkDataset worker + /// threads. + void add_chunk_data(UnwrappedBatchType data) { + std::unique_lock lock(queue_mutex_); + cv_write_.wait(lock, [this] { + // stop loading if we have preloaded enough data. + return this->total_example_count_in_queue_ < this->queue_capacity_ || + this->stop_; + }); + if (stop_) { + // When stop_ is true, it means no further chunk loading is necessary. + // Return without any further processing. + return; + } + + auto data_size = data.size(); + auto remaining_size = data_size; + example_sampler_.reset(data_size); + + auto fill_batch = [&](size_t example_count, UnwrappedBatchType& batch) { + auto batch_example_indices = this->example_sampler_.next(example_count); + AT_ASSERT( + batch_example_indices && + batch_example_indices.value().size() == example_count); + BatchRequestType& indices = batch_example_indices.value(); + for (size_t i : indices) { + TORCH_CHECK(i < data_size, "Index out of range"); + batch.emplace_back(std::move(data[i])); + } + remaining_size -= example_count; + }; + + if (!batch_queue_.empty()) { + // if the queue has existing data, and the last batch doesn't have enough + // examples to fill a batch_size batch, add more example to this batch + // first. + auto& batch = batch_queue_.back(); + size_t current_count = batch.batch_data.size(); + if (current_count < batch_size_) { + auto example_count = + std::min(remaining_size, batch_size_ - current_count); + fill_batch(example_count, batch.batch_data); + } + } + + // If we still have data remaining after filling the last pushed batch, add + // them to the queue too. + while (remaining_size > 0) { + UnwrappedBatchType current_batch; + + // Allocate the batch memory ahead of time. + current_batch.reserve(batch_size_); + + auto example_count = std::min(remaining_size, batch_size_); + fill_batch(example_count, current_batch); + batch_queue_.emplace(std::move(current_batch)); + } + total_example_count_in_queue_ += data_size; + lock.unlock(); + cv_read_.notify_all(); + } + + /// Push exceptions thrown during preloading into batch queue. Called from + /// the ChunkDataset worker threads. + void add_chunk_data(std::exception_ptr e_ptr) { + std::unique_lock lock(queue_mutex_); + cv_write_.wait(lock, [this] { + // stop loading if we have preloaded enough data. + return ( + this->total_example_count_in_queue_ < this->queue_capacity_ || + this->stop_); + }); + if (stop_) { + // When stop_ is true, it means this current thread needs to be tore down, + // the batch buffer will be discarded, so no need to enqueue any new + // exceptions. + return; + } + + batch_queue_.emplace(e_ptr); + lock.unlock(); + cv_read_.notify_all(); + } + + void stop() { + { + // Hold the lock before changing stop_ to prevent a race condition which + // can cause a deadlock. To be more specific, conditional variable + // cv_write_ waits on predicate stop_ in add_chunk_data(). The wait + // happens in two steps: 1) while still holding the lock, check if + // predicate is true; 2) if it is true, proceeds, otherwise, release the + // lock and wait until notified. Without holding a lock, cv_write_'s + // notification can happen in between step 1) and 2). In that case, as + // cv_write_ is not in waiting status yet, so the notification is lost and + // cv_write_ will sleep forever. By taking a lock before changing + // predicate stop_, it is ensured updating and evaluating stop_ always + // happen in a synchronized way + std::lock_guard lock(queue_mutex_); + stop_ = true; + } + + // notify all writers, wake them from wait to exit current method. + cv_write_.notify_all(); + // notify all readers too. + cv_read_.notify_all(); + } + /// The batch size is needed to create batches from the chunk data. Similar to + /// regular dataloader where the batches are created with prefetches, + /// BatchDataBuffer perform the batch creation using the provided batch size. + size_t batch_size_ = 0; + + /// count of total example stored in the queue + size_t total_example_count_in_queue_ = 0; + + /// struct that contains a raw unwrapped batch unit. An unwrapped batch unit + /// is the raw data without 'optional' wrapper. It can be a collection of + /// images, utterances, e.t.c. + struct UnwrappedBatchData { + explicit UnwrappedBatchData(UnwrappedBatchType data) + : batch_data(std::move(data)) {} + + explicit UnwrappedBatchData(std::exception_ptr e) + : exception(std::move(e)) {} + + /// batch data to return + UnwrappedBatchType batch_data; + + /// exception pointer which captures any abnormal exceptions while creating + /// the batch. + std::exception_ptr exception; + }; + + /// local cache to store example batches from loaded chunk + std::queue batch_queue_; + + // sync batch_queue_ update. + std::mutex queue_mutex_; + + std::condition_variable cv_read_; + std::condition_variable cv_write_; + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + ExampleSampler& example_sampler_; + + // configurable maximum number of elements the queue can hold at one time. + size_t queue_capacity_; + + // When set to true, it wakes the writer threads from the wait and exit + // current function call. This is needed when ChunkDataSet.Reset is called + // while the previous epoch is not exhausted yet. When ChunkDataset is waiting + // its preloader to finish previous work before tearing down the thread, the + // preloader could be still waiting for the conditional variable, thus cause + // the program to hang. This boolean is used to break this waiting condition. + bool stop_ = false; +}; +} // namespace detail + +/// Options to configure a `ChunkDataset`. +struct ChunkDatasetOptions { + ChunkDatasetOptions() = delete; + ChunkDatasetOptions( + size_t preloader_count, + size_t batch_size, + size_t cache_size = 2048, + size_t cross_chunk_shuffle_count = 1) + : preloader_count_(preloader_count), + batch_size_(batch_size), + cache_size_(cache_size), + cross_chunk_shuffle_count_(cross_chunk_shuffle_count) { + TORCH_CHECK( + preloader_count_ > 0, + "Preloader count is 0. At least one preloader needs to be specified."); + TORCH_CHECK( + batch_size_ > 0, + "Batch size is 0. A positive batch size needs to be specified."); + TORCH_CHECK( + cache_size_ > 0, + "Cache size is 0. A positive cache size needs to be specified."); + TORCH_CHECK( + cache_size_ >= batch_size_, + "Cache size is less than batch size. Cache needs to be large enough to " + "hold at least one batch."); + TORCH_CHECK( + cross_chunk_shuffle_count_ > 0, + "cross_chunk_shuffle_count needs to be greater than 0."); + } + + /// The number of worker thread to preload chunk data. + TORCH_ARG(size_t, preloader_count); + + /// The size of each batch. + TORCH_ARG(size_t, batch_size); + + /// The capacity of the queue for batch caching. + TORCH_ARG(size_t, cache_size) = 2048; + + // The number of chunks to perform cross-chunk shuffling. Default to 1 meaning + // no cross-chunk shuffling. When it is equal to n (n > 1), n random + // chunks will be loaded at once and example shuffling will be performed + // across all those n chunks. + // Note: Usually the default config (1 chunk shuffle + example shuffle) is + // good enough to generate random distributed data. Use this parameter only if + // you know cross-shuffle is needed in your case. Also there is a performance + // penalty when this value is greater than 1, as we need to do extra merge + // between multiple chunks before performing example sampling. + TORCH_ARG(size_t, cross_chunk_shuffle_count) = 1; +}; + +/// A stateful dataset that support hierarchical sampling and prefetching of +/// entre chunks. +/// +/// Unlike regular dataset, chunk dataset require two samplers to operate and +/// keeps an internal state. `ChunkSampler` selects, which chunk to load next, +/// while the `ExampleSampler` determines the order of Examples that are +/// returned in each `get_batch` call. The hierarchical sampling approach used +/// here is inspired by this paper +/// http://martin.zinkevich.org/publications/nips2010.pdf +template < + typename ChunkReader, + typename ChunkSampler = samplers::RandomSampler, + typename ExampleSampler = samplers::RandomSampler> +class ChunkDataset final + : public StatefulDataset< + ChunkDataset, + typename ChunkReader::BatchType, + size_t> { + public: + using BatchType = std::optional; + using UnwrappedBatchType = typename ChunkReader::BatchType; + using BatchRequestType = size_t; + using ChunkSamplerType = ChunkSampler; + using ExampleSamplerType = ExampleSampler; + + ChunkDataset( + ChunkReader chunk_reader, + ChunkSampler chunk_sampler, + ExampleSampler example_sampler, + ChunkDatasetOptions options, + std::function preprocessing_policy = + std::function()) + : chunk_reader_(std::move(chunk_reader)), + chunk_sampler_(std::move(chunk_sampler)), + example_sampler_(std::move(example_sampler)), + options_(options), + preprocessing_policy_(std::move(preprocessing_policy)), + quit_worker_(false), + running_preloaders_(0) {} + + ~ChunkDataset() override { + // stop batch buffer first. + if (batch_buffer_) { + batch_buffer_->stop(); + } + free_workers(); + } + + /// Default get_batch method of BatchDataset. This method returns + /// Example batches created from the preloaded chunks. The implementation + /// is dataset agnostic and does not need overriding in different chunk + /// datasets. + BatchType get_batch(size_t batch_size) override { + TORCH_CHECK( + batch_buffer_ != nullptr, + "Dataset needs to call reset() before calling get_batch()."); + + TORCH_CHECK( + batch_size == options_.batch_size(), + "The requested batch size does not match with the initialized batch size.\n" + " The requested batch size is ", + batch_size, + ", while the dataset is created with batch size equal to ", + options_.batch_size()); + return batch_buffer_->get_batch(); + } + + /// Helper method around get_batch as `batch_size` is not strictly necessary + BatchType get_batch() { + return get_batch(options_.batch_size()); + } + + /// This will clear any internal state and starts the internal prefetching + /// mechanism for the chunk dataset. + void reset() override { + // We need this to support partial data reads via dataloader iterator. + if (batch_buffer_) { + batch_buffer_->stop(); + } + // free workers from previous reset if there is any. + free_workers(); + preload_threads_.clear(); + + if (!load_checkpoint_) { + chunk_reader_.reset(); + chunk_sampler_.reset(chunk_reader_.chunk_count()); + load_checkpoint_ = false; + } + + // Throw out any existing cached batch in the buffer and re-creates a new + // chunk buffer. + batch_buffer_ = std::make_unique< + detail::BatchDataBuffer>( + options_.batch_size(), example_sampler_, options_.cache_size()); + + // create new workers for this new epoch. + quit_worker_ = false; + + AT_ASSERT(running_preloaders_ == 0); + running_preloaders_ = options_.preloader_count(); + for (const auto i : c10::irange(options_.preloader_count())) { + preload_threads_.emplace_back([this, i]() { this->preloader(i); }); + } + } + + /// size is not used for chunk dataset. + std::optional size() const override { + return std::nullopt; + } + + // provide a references to chunk sampler. Used mainly in distributed data + // loading to set the epoch number for the sampler. + ChunkSamplerType& chunk_sampler() { + return chunk_sampler_; + } + + void save(serialize::OutputArchive& archive) const override { + std::lock_guard lock(chunk_index_guard_); + chunk_sampler_.save(archive); + } + + void load(serialize::InputArchive& archive) override { + std::lock_guard lock(chunk_index_guard_); + chunk_sampler_.load(archive); + load_checkpoint_ = true; + } + + private: + /// running on worker thread to preload chunk data. + void preloader(size_t id) { + while (!quit_worker_.load()) { + try { + std::vector chunk_idx; + { + std::lock_guard lock(chunk_index_guard_); + if (auto chunk_sampler_result = chunk_sampler_.next( + this->options_.cross_chunk_shuffle_count())) { + chunk_idx = chunk_sampler_result.value(); + } else { + break; + } + } + UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_idx[0]); + for (const auto i : c10::irange(1, chunk_idx.size())) { + auto chunk_data = chunk_reader_.read_chunk(chunk_idx[i]); + std::move( + chunk_data.begin(), chunk_data.end(), std::back_inserter(data)); + } + if (preprocessing_policy_) { + preprocessing_policy_(data); + } + if (!data.empty()) { // skip empty chunks. + batch_buffer_->add_chunk_data(std::move(data)); + } + } catch (...) { + batch_buffer_->add_chunk_data(std::current_exception()); + } + } + AT_ASSERT(running_preloaders_.load() > 0); + --running_preloaders_; + if (running_preloaders_.load() == 0) { + // all preloaders are completed, so we can notify the batch_buffer. + batch_buffer_->stop(); + } + } + + /// Block the current thread until the workers finish execution and exit. + void free_workers() { + if (!quit_worker_.load()) { + quit_worker_ = true; + for (auto& worker_thread : preload_threads_) { + worker_thread.join(); + } + } + } + + private: + // Templated class that defines what is a chunk and how to read chunk data. + // When a chunk is returned by chunk_reader_, ChunkDataset split it into + // batches and caches them in batch_buffer_. + ChunkReader chunk_reader_; + + // chunk sampler to shuffle different chunks + ChunkSamplerType chunk_sampler_; + + // example sampler to shuffle examples in a specific chunk + ExampleSamplerType example_sampler_; + + // batch data buffer which holds chunk data from preloading thread. + std::shared_ptr< + detail::BatchDataBuffer> + batch_buffer_; + + // worker thread pool + std::vector preload_threads_; + + /// The options the Dataset was configured with. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const ChunkDatasetOptions options_; + + // function pointer wrapper to apply custom processing over chunk data. This + // is considered an advanced parameter for developers who want to apply a + // pre-process to the chunk data before sampling into minibatch. + // Different than the collate function, this policy is applied on the chunk + // level, instead of minibatch level. When a chunk of data is loaded (multiple + // chunks if cross_chunk_shuffle_count_ is greater than 1), this policy is + // applied to the full loaded data. It is useful if developers want to + // perform pre-processing (like bucketing) to the chunk data before + // example sampler samples the data. By default it's an empty pointer and no + // action will be taken. + std::function preprocessing_policy_; + + // indicate whether the worker thread can be teared down + std::atomic quit_worker_; + + // keep track of running preloaders to notify batch buffer. A value 0 + // indicates that the chunk loading is completed. + std::atomic running_preloaders_; + + // mutex to synchronize chunk sampler next() call. + mutable std::mutex chunk_index_guard_; + + // boolean value to indicate whether we need to load the checkpoint for + // chunk_sampler_. + bool load_checkpoint_{false}; +}; +} // namespace torch::data::datasets diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/map.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/map.h new file mode 100644 index 0000000000000000000000000000000000000000..6c4afd95501e9240acd127d0f12a2f0585a24645 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/map.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include + +#include + +#include +#include +#include + +namespace torch::data::datasets { +namespace detail { +template +using optional_if_t = std::conditional_t, T>; +} // namespace detail + +/// A `MapDataset` is a dataset that applies a transform to a source dataset. +template +class MapDataset : public BatchDataset< + MapDataset, + detail::optional_if_t< + SourceDataset::is_stateful, + typename AppliedTransform::OutputBatchType>, + typename SourceDataset::BatchRequestType> { + public: + using DatasetType = SourceDataset; + using TransformType = AppliedTransform; + using BatchRequestType = typename SourceDataset::BatchRequestType; + using OutputBatchType = detail::optional_if_t< + SourceDataset::is_stateful, + typename AppliedTransform::OutputBatchType>; + + MapDataset(DatasetType dataset, TransformType transform) + : dataset_(std::move(dataset)), transform_(std::move(transform)) {} + + /// Gets a batch from the source dataset and applies the transform to it, + /// returning the result. + OutputBatchType get_batch(BatchRequestType indices) override { + return get_batch_impl(std::move(indices)); + } + + /// Returns the size of the source dataset. + // NOLINTNEXTLINE(bugprone-exception-escape) + std::optional size() const noexcept override { + return dataset_.size(); + } + + /// Calls `reset()` on the underlying dataset. + /// NOTE: Stateless datasets do not have a reset() method, so a call to this + /// method will only compile for stateful datasets (which have a reset() + /// method). + void reset() { + dataset_.reset(); + } + + /// Returns the underlying dataset. + const SourceDataset& dataset() noexcept { + return dataset_; + } + + /// Returns the transform being applied. + const AppliedTransform& transform() noexcept { + return transform_; + } + + private: + /// The implementation of `get_batch()` for the stateless case, which simply + /// applies the transform to the output of `get_batch()` from the dataset. + template < + typename D = SourceDataset, + typename = std::enable_if_t> + OutputBatchType get_batch_impl(BatchRequestType indices) { + return transform_.apply_batch(dataset_.get_batch(std::move(indices))); + } + + /// The implementation of `get_batch()` for the stateful case. Here, we follow + /// the semantics of `Optional.map()` in many functional languages, which + /// applies a transformation to the optional's content when the optional + /// contains a value, and returns a new optional (of a different type) if the + /// original optional returned by `get_batch()` was empty. + template + std::enable_if_t get_batch_impl( + BatchRequestType indices) { + if (auto batch = dataset_.get_batch(std::move(indices))) { + return transform_.apply_batch(std::move(*batch)); + } + return std::nullopt; + } + + /// The underlying dataset being transformed. + SourceDataset dataset_; + + // The transformation that is applied to batches received from the dataset. + AppliedTransform transform_; +}; + +/// Creates a `MapDataset` with the given dataset and transform. +template +MapDataset map( + DatasetType dataset, + TransformType transform) { + static_assert( + std::is_same_v< + std::conditional_t< + DatasetType::is_stateful, + typename DatasetType::BatchType::value_type, + typename DatasetType::BatchType>, + typename TransformType::InputBatchType>, + "BatchType type of dataset does not match input type of transform"); + return {std::move(dataset), std::move(transform)}; +} + +} // namespace torch::data::datasets diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/mnist.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/mnist.h new file mode 100644 index 0000000000000000000000000000000000000000..c19a862ba99f705f357bbb821448801dd8649b3f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/mnist.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include + +#include + +#include +#include + +namespace torch::data::datasets { +/// The MNIST dataset. +class TORCH_API MNIST : public Dataset { + public: + /// The mode in which the dataset is loaded. + enum class Mode { kTrain, kTest }; + + /// Loads the MNIST dataset from the `root` path. + /// + /// The supplied `root` path should contain the *content* of the unzipped + /// MNIST dataset, available from http://yann.lecun.com/exdb/mnist. + explicit MNIST(const std::string& root, Mode mode = Mode::kTrain); + + /// Returns the `Example` at the given `index`. + Example<> get(size_t index) override; + + /// Returns the size of the dataset. + std::optional size() const override; + + /// Returns true if this is the training subset of MNIST. + // NOLINTNEXTLINE(bugprone-exception-escape) + bool is_train() const noexcept; + + /// Returns all images stacked into a single tensor. + const Tensor& images() const; + + /// Returns all targets stacked into a single tensor. + const Tensor& targets() const; + + private: + Tensor images_, targets_; +}; +} // namespace torch::data::datasets diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/shared.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/shared.h new file mode 100644 index 0000000000000000000000000000000000000000..725cfb5ffdf4a7b284ca5b25066393436b0e12c5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/shared.h @@ -0,0 +1,79 @@ +#pragma once + +#include + +#include +#include + +namespace torch::data::datasets { + +/// A dataset that wraps another dataset in a shared pointer and implements the +/// `BatchDataset` API, delegating all calls to the shared instance. This is +/// useful when you want all worker threads in the dataloader to access the same +/// dataset instance. The dataset must take care of synchronization and +/// thread-safe access itself. +/// +/// Use `torch::data::datasets::make_shared_dataset()` to create a new +/// `SharedBatchDataset` like you would a `std::shared_ptr`. +template +class SharedBatchDataset : public BatchDataset< + SharedBatchDataset, + typename UnderlyingDataset::BatchType, + typename UnderlyingDataset::BatchRequestType> { + public: + using BatchType = typename UnderlyingDataset::BatchType; + using BatchRequestType = typename UnderlyingDataset::BatchRequestType; + + /// Constructs a new `SharedBatchDataset` from a `shared_ptr` to the + /// `UnderlyingDataset`. + /* implicit */ SharedBatchDataset( + std::shared_ptr shared_dataset) + : dataset_(std::move(shared_dataset)) {} + + /// Calls `get_batch` on the underlying dataset. + BatchType get_batch(BatchRequestType request) override { + return dataset_->get_batch(std::move(request)); + } + + /// Returns the `size` from the underlying dataset. + std::optional size() const override { + return dataset_->size(); + } + + /// Accesses the underlying dataset. + UnderlyingDataset& operator*() { + return *dataset_; + } + + /// Accesses the underlying dataset. + const UnderlyingDataset& operator*() const { + return *dataset_; + } + + /// Accesses the underlying dataset. + UnderlyingDataset* operator->() { + return dataset_.get(); + } + + /// Accesses the underlying dataset. + const UnderlyingDataset* operator->() const { + return dataset_.get(); + } + + /// Calls `reset()` on the underlying dataset. + void reset() { + dataset_->reset(); + } + + private: + std::shared_ptr dataset_; +}; + +/// Constructs a new `SharedBatchDataset` by creating a +/// `shared_ptr`. All arguments are forwarded to +/// `make_shared`. +template +SharedBatchDataset make_shared_dataset(Args&&... args) { + return std::make_shared(std::forward(args)...); +} +} // namespace torch::data::datasets diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/stateful.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/stateful.h new file mode 100644 index 0000000000000000000000000000000000000000..adc210fcf3d5e27baae7f17aca549e88a1e244b0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/stateful.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch::serialize { +class OutputArchive; +class InputArchive; +} // namespace torch::serialize + +namespace torch::data::datasets { + +/// A stateful dataset is a dataset that maintains some internal state, which +/// will be `reset()` at the beginning of each epoch. Subclasses can override +/// the `reset()` method to configure this behavior. Further, the return type of +/// a stateful dataset's `get_batch()` method is always an `optional`. When the +/// stateful dataset wants to indicate to the dataloader that its epoch has +/// ended, it should return an empty optional. The dataloader knows to modify +/// its implementation based on whether the dataset is stateless or stateful. +/// +/// Note that when subclassing a from `StatefulDataset`, the return +/// type of `get_batch()`, which the subclass must override, will be +/// `optional` (i.e. the type specified in the `StatefulDataset` +/// specialization is automatically boxed into an `optional` for the dataset's +/// `BatchType`). +template < + typename Self, + typename Batch = std::vector>, + typename BatchRequest = size_t> +class StatefulDataset + : public BatchDataset, BatchRequest> { + public: + /// Resets internal state of the dataset. + virtual void reset() = 0; + + /// Saves the statefulDataset's state to OutputArchive. + virtual void save(serialize::OutputArchive& archive) const = 0; + + /// Deserializes the statefulDataset's state from the `archive`. + virtual void load(serialize::InputArchive& archive) = 0; +}; + +/// Serializes a statefulDataset to `OutputArchive`. +template +serialize::OutputArchive& operator<<( + serialize::OutputArchive& archive, + const StatefulDataset& statefulDataset) { + statefulDataset.save(archive); + return archive; +} + +/// Deserializes a statefulDataset from an `InputArchive`. +template +serialize::InputArchive& operator>>( + serialize::InputArchive& archive, + StatefulDataset& statefulDataset) { + statefulDataset.load(archive); + return archive; +} + +} // namespace torch::data::datasets diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/tensor.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..1c9fd2130fe649aba36fb8c3b2dcbdc541a068e1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/datasets/tensor.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::data::datasets { + +/// A dataset of tensors. +/// Stores a single tensor internally, which is then indexed inside `get()`. +struct TensorDataset : public Dataset { + /// Creates a `TensorDataset` from a vector of tensors. + explicit TensorDataset(const std::vector& tensors) + : TensorDataset(torch::stack(tensors)) {} + + explicit TensorDataset(torch::Tensor tensor) : tensor(std::move(tensor)) {} + + /// Returns a single `TensorExample`. + TensorExample get(size_t index) override { + return tensor[static_cast(index)]; + } + + /// Returns the number of tensors in the dataset. + std::optional size() const override { + return tensor.size(0); + } + + Tensor tensor; +}; + +} // namespace torch::data::datasets diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/data_shuttle.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/data_shuttle.h new file mode 100644 index 0000000000000000000000000000000000000000..6538c2b449c8efc27ac568d6f5728d512edb7848 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/data_shuttle.h @@ -0,0 +1,83 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include + +namespace torch::data::detail { + +/// Encapsulates the full life cycle of DataLoader jobs. +/// +/// When a new job is enqueued to the `DataShuttle`, a counter for in-flight +/// jobs is bumped. This job is said to be "in-flight" until its result is +/// popped. Worker threads dequeue jobs as soon as they are available. When a +/// worker finishes a job, it enqueues the result. Only when the main thread +/// dequeues a result is the count of in-flight jobs decremented. When the main +/// thread attempts to dequeue a job but no jobs are in-flight, that means the +/// epoch is complete and `pop_result` returns an empty optional. +template +class DataShuttle { + public: + /// Pushes a new job. Called by the main thread. + void push_job(Job job) { + new_jobs_.push(std::move(job)); + ++in_flight_jobs_; + } + + /// Pushes the result of a job. Called by worker threads. + void push_result(Result result) { + results_.push(std::move(result)); + } + + /// Returns the next job, blocking until there is one available. Called by + /// worker threads. + Job pop_job() { + return new_jobs_.pop(); + } + + /// Returns the result of a job, or nullopt if all jobs were exhausted. Called + /// by the main thread. + std::optional pop_result( + std::optional timeout = std::nullopt) { + if (in_flight_jobs_ > 0) { + auto result = results_.pop(timeout); + --in_flight_jobs_; + return result; + } + return std::nullopt; + } + + /// Discards any jobs that are not yet in flight, and waits for all in-flight + /// jobs to finish, discarding their result. + void drain() { + // Clear all inputs so that no further jobs are scheduled. + auto number_cleared = new_jobs_.clear(); + in_flight_jobs_ -= number_cleared; + // Remove any outstanding results. + while (in_flight_jobs_ > 0) { + pop_result(); + } + } + + /// Returns the number of jobs that are still in progress. + /// When this number is zero, an epoch is finished. + size_t in_flight_jobs() const noexcept { + return in_flight_jobs_; + } + + private: + /// The queue for jobs that are not yet in flight. + Queue new_jobs_; + /// The number of in-flight jobs. + /// NOTE: Not atomic because only manipulated by the main thread. + size_t in_flight_jobs_ = 0; + /// The queue for results of finished jobs. + Queue results_; +}; + +} // namespace torch::data::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/queue.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/queue.h new file mode 100644 index 0000000000000000000000000000000000000000..71752d1af3f78b617daa42f9db5ff2b58011b2ca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/queue.h @@ -0,0 +1,80 @@ +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include + +namespace torch::data::detail { + +/// A basic locked, blocking MPMC queue. +/// +/// Every `push` and `pop` is guarded by a mutex. A condition variable is used +/// to communicate insertion of new elements, such that waiting threads will be +/// woken up if they are currently waiting inside a call to `pop()`. +/// +/// Note that this data structure is written specifically for use with the +/// `DataLoader`. Its behavior is tailored to this use case and may not be +/// applicable to more general uses. +template +class Queue { + public: + /// Pushes a new value to the back of the `Queue` and notifies one thread on + /// the waiting side about this event. + void push(T value) { + { + std::lock_guard lock(mutex_); + queue_.push(std::move(value)); + } + cv_.notify_one(); + } + + /// Blocks until at least one element is ready to be popped from the front of + /// the queue. An optional `timeout` in seconds can be used to limit the time + /// spent waiting for an element. If the wait times out, an exception is + /// raised. + T pop(std::optional timeout = std::nullopt) { + std::unique_lock lock(mutex_); + if (timeout) { + if (!cv_.wait_for( + lock, *timeout, [this] { return !this->queue_.empty(); })) { + // clang-format off + TORCH_CHECK(false, + "Timeout in DataLoader queue while waiting for next batch" + " (timeout was ", timeout->count(), " ms)"); + // clang-format on + } + } else { + cv_.wait(lock, [this] { return !this->queue_.empty(); }); + } + AT_ASSERT(!queue_.empty()); + T value = queue_.front(); + queue_.pop(); + lock.unlock(); + return value; + } + + /// Empties the queue and returns the number of elements that were present at + /// the start of the function. No threads are notified about this event as it + /// is assumed to be used to drain the queue during shutdown of a + /// `DataLoader`. + size_t clear() { + std::lock_guard lock(this->mutex_); + const auto size = queue_.size(); + while (!queue_.empty()) { + queue_.pop(); + } + return size; + } + + private: + std::queue queue_; + std::mutex mutex_; + std::condition_variable cv_; +}; +} // namespace torch::data::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/sequencers.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/sequencers.h new file mode 100644 index 0000000000000000000000000000000000000000..69004d55fefe5f2947ef90dabf4357c64d3b1d7f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/detail/sequencers.h @@ -0,0 +1,107 @@ +#pragma once + +#include + +#include +#include +#include + +namespace torch::data::detail::sequencers { +namespace detail { +template +bool buffer_contains_result(const std::vector>& buffer) { + return std::any_of( + buffer.begin(), buffer.end(), [](const std::optional& result) { + return result.has_value(); + }); +} +} // namespace detail + +/// A `Sequencer` accepts a function that yields the next result of a +/// `DataLoader` and then has the opportunity to influence the order in which +/// these results are returned. The `NoSequencer` does not enforce any +/// sequencing and returns any result directly. The `OrderedSequencer` instead +/// buffers results internally to return them in order of their sequence number. +template +struct Sequencer { + using ResultProducer = std::function()>; + virtual ~Sequencer() = default; + virtual std::optional next(ResultProducer next_result) = 0; +}; + +/// A `Sequencer` that does not enforce any ordering. It is effectively the +/// identity function. +template +struct NoSequencer final : public Sequencer { + using typename Sequencer::ResultProducer; + std::optional next(ResultProducer next_result) override { + return next_result(); + } +}; + +/// A `Sequencer` that buffers results and returns them in order of their +/// sequence number. The `OrderedSequencer` maintains an internal, monotonically +/// incrementing counter for the next sequence number it expects. If it receives +/// a result with a higher sequence number, it will buffer it for later (when +/// the sequence number reaches that of this result). Otherwise, if the sequence +/// numbers match, the result is returned. +/// +/// Implementation note: The `OrderedSequencer` is implemented with a fixed-size +/// buffer. Let `m` be the maximum number of jobs in the data loader's queue and +/// `s` be the current sequence number. Assume `m` jobs are scheduled in the +/// `DataLoader`. Any new result is stored at index `job.sqn mod m` in the +/// `OrderedSequencer`. Why are we sure sequence numbers of new jobs will not +/// collide with sequence numbers of buffered jobs? The `OrderedSequencer` will +/// not return from `next()` until it receives the result with sqn `s`. This +/// means no new jobs can be scheduled in the `DataLoader` in the meantime, +/// which enforces that as long as sqn `s` has not been received, `s + m` (which +/// would cause a collision in the fixed-size buffer) will not yet be scheduled. +template +struct OrderedSequencer : public Sequencer { + using typename Sequencer::ResultProducer; + + /// Constructs the `OrderedSequencer` with the maximum number of results it + /// will ever hold at one point in time. + explicit OrderedSequencer(size_t max_jobs) : buffer_(max_jobs) {} + + /// Buffers results until the next one in the expected order is received. + std::optional next(ResultProducer next_result) override { + // If we already have the result for the next sqn, return it. + if (auto& maybe_result = buffer(next_sequence_number_)) { + auto result = std::move(*maybe_result); + buffer(next_sequence_number_++).reset(); + return result; + } + // Otherwise wait for the next result. + while (true) { + auto result = next_result(); + if (!result) { + AT_ASSERT(!detail::buffer_contains_result(buffer_)); + break; + } + // If it was not nullopt and the sequence numbers match, return it + // directly and bump the sequence number. + if (result->sequence_number == next_sequence_number_) { + ++next_sequence_number_; + return result; + } + // Stash the result for later. + AT_ASSERT(!buffer(result->sequence_number).has_value()); + buffer(result->sequence_number) = std::move(result); + } + // The result was an empty optional, so we are done with this epoch. + return std::nullopt; + } + + /// Accesses the buffer at the `index` modulo the buffer size. + std::optional& buffer(size_t index) { + return buffer_.at(index % buffer_.size()); + } + + /// The monotonically increasing sequence number we expect. + size_t next_sequence_number_ = 0; + + /// A fixed-size buffer (after construction). + std::vector> buffer_; +}; +} // namespace torch::data::detail::sequencers diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/example.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/example.h new file mode 100644 index 0000000000000000000000000000000000000000..af4b08371a82b7c907cb1ae74631bae9357abc54 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/example.h @@ -0,0 +1,53 @@ +#pragma once + +#include + +namespace torch::data { + +/// An `Example` from a dataset. +/// +/// A dataset consists of data and an associated target (label). +template +struct Example { + using DataType = Data; + using TargetType = Target; + + Example() = default; + Example(Data data, Target target) + : data(std::move(data)), target(std::move(target)) {} + + Data data; + Target target; +}; + +namespace example { +using NoTarget = void; +} // namespace example + +/// A specialization for `Example` that does not have a target. +/// +/// This class exists so that code can be written for a templated `Example` +/// type, and work both for labeled and unlabeled datasets. +template +struct Example { + using DataType = Data; + using TargetType = example::NoTarget; + + Example() = default; + /* implicit */ Example(Data data) : data(std::move(data)) {} + + // When a DataLoader returns an Example like this, that example should be + // implicitly convertible to the underlying data type. + + operator Data&() { + return data; + } + operator const Data&() const { + return data; + } + + Data data; +}; + +using TensorExample = Example; +} // namespace torch::data diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/iterator.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..a0ee28a73e0180f823b02d07b676188cb1b1dc2e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/iterator.h @@ -0,0 +1,178 @@ +#pragma once + +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace torch::data { +namespace detail { +// For increased safety and more separated logic, this implementation of +// `Iterator` consists of a `ValidIterator` and a `SentinelIterator`. A +// `ValidIterator` yields new batches until the `DataLoader` is exhausted. While +// the `DataLoader` is not exhausted, `ValidIterator`s compare equal if they are +// the same object. When the `ValidIterator` becomes exhausted, it compares +// equal to the `SentinelIterator`, but not before. Half the code here is to +// implement double dispatch for the comparison. Got damnit, C++. + +template +struct ValidIterator; + +template +struct SentinelIterator; + +/// Base class for the `ValidIterator` and `SentinelIterator` +template +struct IteratorImpl { + virtual ~IteratorImpl() = default; + virtual void next() = 0; + virtual Batch& get() = 0; + virtual bool operator==(const IteratorImpl& other) const = 0; + virtual bool operator==(const ValidIterator& other) const = 0; + virtual bool operator==(const SentinelIterator& other) const = 0; +}; + +template +struct ValidIterator : public IteratorImpl { + using BatchProducer = std::function()>; + + explicit ValidIterator(BatchProducer next_batch) + : next_batch_(std::move(next_batch)) {} + + /// Fetches the next batch. + void next() override { + // If we didn't get the very first batch yet, get it now. + lazy_initialize(); + TORCH_CHECK( + batch_.has_value(), "Attempted to increment iterator past the end"); + // Increment to the next batch. + batch_ = next_batch_(); + } + + /// Returns the current batch. The precondition for this operation to not + /// throw an exception is that it has been compared to the `SentinelIterator` + /// and did not compare equal. + Batch& get() override { + // If we didn't get the very first batch yet, get it now. + lazy_initialize(); + TORCH_CHECK( + batch_.has_value(), + "Attempted to dereference iterator that was past the end"); + return batch_.value(); + } + + /// Does double dispatch. + bool operator==(const IteratorImpl& other) const override { + return other == *this; + } + + /// A `ValidIterator` is equal to the `SentinelIterator` iff. the + /// `ValidIterator` has reached the end of the dataloader. + bool operator==(const SentinelIterator& /* unused */) const override { + lazy_initialize(); + return !batch_; + } + + /// Returns true if the memory address of `other` equals that of `this`. + bool operator==(const ValidIterator& other) const override { + return &other == this; + } + + /// Gets the very first batch if it has not yet been fetched. + void lazy_initialize() const { + if (!initialized_) { + batch_ = next_batch_(); + initialized_ = true; + } + } + + BatchProducer next_batch_; + mutable std::optional batch_; + mutable bool initialized_ = false; +}; + +template +struct SentinelIterator : public IteratorImpl { + void next() override { + TORCH_CHECK( + false, + "Incrementing the DataLoader's past-the-end iterator is not allowed"); + } + + Batch& get() override { + TORCH_CHECK( + false, + "Dereferencing the DataLoader's past-the-end iterator is not allowed"); + } + + /// Does double dispatch. + bool operator==(const IteratorImpl& other) const override { + return other == *this; + } + + /// Calls the comparison operator between `ValidIterator` and + /// `SentinelIterator`. + bool operator==(const ValidIterator& other) const override { + return other == *this; + } + + /// Sentinel iterators always compare equal. + bool operator==(const SentinelIterator& other) const override { + return true; + } +}; +} // namespace detail + +template +class Iterator { + public: + // Type aliases to make the class recognized as a proper iterator. + using difference_type = std::ptrdiff_t; + using value_type = Batch; + using pointer = Batch*; + using reference = Batch&; + using iterator_category = std::input_iterator_tag; + + explicit Iterator(std::unique_ptr> impl) + : impl_(std::move(impl)) {} + + /// Increments the iterator. + /// Only permitted for valid iterators (not past the end). + Iterator& operator++() { + impl_->next(); + return *this; + } + + /// Returns the current batch. + /// Only permitted for valid iterators (not past the end). + Batch& operator*() { + return impl_->get(); + } + + /// Returns a pointer to the current batch. + /// Only permitted for valid iterators (not past the end). + Batch* operator->() { + return &impl_->get(); + } + + /// Compares two iterators for equality. + bool operator==(const Iterator& other) const { + return *impl_ == *other.impl_; + } + + /// Compares two iterators for inequality. + bool operator!=(const Iterator& other) const { + return !(*this == other); + } + + private: + /// Points either to a `ValidIterator` or to a `SentinelIterator`. + std::shared_ptr> impl_; +}; +} // namespace torch::data diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers.h new file mode 100644 index 0000000000000000000000000000000000000000..928a2412aa76f8a22574b433a2f61152c45ae5c7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/base.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/base.h new file mode 100644 index 0000000000000000000000000000000000000000..ebaf40848abcef249e0fab687484fa5874bbb1e7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/base.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch::serialize { +class OutputArchive; +class InputArchive; +} // namespace torch::serialize + +namespace torch::data::samplers { +/// A `Sampler` is an object that yields an index with which to access a +/// dataset. +template > +class Sampler { + public: + using BatchRequestType = BatchRequest; + + virtual ~Sampler() = default; + + /// Resets the `Sampler`'s internal state. + /// Typically called before a new epoch. + /// Optionally, accepts a new size when resetting the sampler. + virtual void reset(std::optional new_size) = 0; + + /// Returns the next index if possible, or an empty optional if the + /// sampler is exhausted for this epoch. + virtual std::optional next(size_t batch_size) = 0; + + /// Serializes the `Sampler` to the `archive`. + virtual void save(serialize::OutputArchive& archive) const = 0; + + /// Deserializes the `Sampler` from the `archive`. + virtual void load(serialize::InputArchive& archive) = 0; +}; + +} // namespace torch::data::samplers diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h new file mode 100644 index 0000000000000000000000000000000000000000..7132856fe235951be707420bc788cb8ed86dc774 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include + +namespace torch::data::samplers { +/// A base class for custom index types. +struct TORCH_API CustomBatchRequest { + CustomBatchRequest() = default; + CustomBatchRequest(const CustomBatchRequest&) = default; + CustomBatchRequest(CustomBatchRequest&&) noexcept = default; + virtual ~CustomBatchRequest() = default; + + /// The number of elements accessed by this index. + virtual size_t size() const = 0; +}; +} // namespace torch::data::samplers diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/distributed.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/distributed.h new file mode 100644 index 0000000000000000000000000000000000000000..64be81645dcc6cd7b8eccee14f28ed9f43ca918d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/distributed.h @@ -0,0 +1,133 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch::serialize { +class OutputArchive; +class InputArchive; +} // namespace torch::serialize + +namespace torch::data::samplers { + +/// A `Sampler` that selects a subset of indices to sample from and defines a +/// sampling behavior. In a distributed setting, this selects a subset of the +/// indices depending on the provided num_replicas and rank parameters. The +/// `Sampler` performs a rounding operation based on the `allow_duplicates` +/// parameter to decide the local sample count. +template > +class DistributedSampler : public Sampler { + public: + DistributedSampler( + size_t size, + size_t num_replicas = 1, + size_t rank = 0, + bool allow_duplicates = true) + : size_(size), + num_replicas_(num_replicas), + rank_(rank), + + allow_duplicates_(allow_duplicates) {} + + /// Set the epoch for the current enumeration. This can be used to alter the + /// sample selection and shuffling behavior. + void set_epoch(size_t epoch) { + epoch_ = epoch; + } + + size_t epoch() const { + return epoch_; + } + + protected: + size_t local_sample_count() { + if (allow_duplicates_) { + return (size_ + num_replicas_ - 1) / num_replicas_; + } else { + return size_ / num_replicas_; + } + } + + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + size_t size_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + size_t num_replicas_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + size_t rank_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + size_t epoch_{0}; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool allow_duplicates_; +}; + +/// Select samples randomly. The sampling order is shuffled at each `reset()` +/// call. +class TORCH_API DistributedRandomSampler : public DistributedSampler<> { + public: + DistributedRandomSampler( + size_t size, + size_t num_replicas = 1, + size_t rank = 0, + bool allow_duplicates = true); + + /// Resets the `DistributedRandomSampler` to a new set of indices. + void reset(std::optional new_size = std::nullopt) override; + + /// Returns the next batch of indices. + std::optional> next(size_t batch_size) override; + + /// Serializes the `DistributedRandomSampler` to the `archive`. + void save(serialize::OutputArchive& archive) const override; + + /// Deserializes the `DistributedRandomSampler` from the `archive`. + void load(serialize::InputArchive& archive) override; + + /// Returns the current index of the `DistributedRandomSampler`. + size_t index() const noexcept; + + private: + void populate_indices(); + + size_t begin_index_; + size_t end_index_; + size_t sample_index_; + std::vector all_indices_; +}; + +/// Select samples sequentially. +class TORCH_API DistributedSequentialSampler : public DistributedSampler<> { + public: + DistributedSequentialSampler( + size_t size, + size_t num_replicas = 1, + size_t rank = 0, + bool allow_duplicates = true); + + /// Resets the `DistributedSequentialSampler` to a new set of indices. + void reset(std::optional new_size = std::nullopt) override; + + /// Returns the next batch of indices. + std::optional> next(size_t batch_size) override; + + /// Serializes the `DistributedSequentialSampler` to the `archive`. + void save(serialize::OutputArchive& archive) const override; + + /// Deserializes the `DistributedSequentialSampler` from the `archive`. + void load(serialize::InputArchive& archive) override; + + /// Returns the current index of the `DistributedSequentialSampler`. + size_t index() const noexcept; + + private: + void populate_indices(); + + size_t begin_index_; + size_t end_index_; + size_t sample_index_; + std::vector all_indices_; +}; + +} // namespace torch::data::samplers diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/random.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/random.h new file mode 100644 index 0000000000000000000000000000000000000000..fc81aae7c3b527823b85157569cfbf61fa041d70 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/random.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::serialize { +class OutputArchive; +class InputArchive; +} // namespace torch::serialize + +namespace torch::data::samplers { + +/// A `Sampler` that returns random indices. +class TORCH_API RandomSampler : public Sampler<> { + public: + /// Constructs a `RandomSampler` with a size and dtype for the stored indices. + /// + /// The constructor will eagerly allocate all required indices, which is the + /// sequence `0 ... size - 1`. `index_dtype` is the data type of the stored + /// indices. You can change it to influence memory usage. + explicit RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64); + + ~RandomSampler() override; + + /// Resets the `RandomSampler` to a new set of indices. + void reset(std::optional new_size = std::nullopt) override; + + /// Returns the next batch of indices. + std::optional> next(size_t batch_size) override; + + /// Serializes the `RandomSampler` to the `archive`. + void save(serialize::OutputArchive& archive) const override; + + /// Deserializes the `RandomSampler` from the `archive`. + void load(serialize::InputArchive& archive) override; + + /// Returns the current index of the `RandomSampler`. + size_t index() const noexcept; + + private: + at::Tensor indices_; + int64_t index_ = 0; +}; +} // namespace torch::data::samplers diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/sequential.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/sequential.h new file mode 100644 index 0000000000000000000000000000000000000000..2b57f90d116f5242bc4f679c24157dba180969af --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/sequential.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::serialize { +class OutputArchive; +class InputArchive; +} // namespace torch::serialize + +namespace torch::data::samplers { + +/// A `Sampler` that returns indices sequentially. +class TORCH_API SequentialSampler : public Sampler<> { + public: + /// Creates a `SequentialSampler` that will return indices in the range + /// `0...size - 1`. + explicit SequentialSampler(size_t size); + + /// Resets the `SequentialSampler` to zero. + void reset(std::optional new_size = std::nullopt) override; + + /// Returns the next batch of indices. + std::optional> next(size_t batch_size) override; + + /// Serializes the `SequentialSampler` to the `archive`. + void save(serialize::OutputArchive& archive) const override; + + /// Deserializes the `SequentialSampler` from the `archive`. + void load(serialize::InputArchive& archive) override; + + /// Returns the current index of the `SequentialSampler`. + size_t index() const noexcept; + + private: + size_t size_; + size_t index_{0}; +}; + +} // namespace torch::data::samplers diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/serialize.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/serialize.h new file mode 100644 index 0000000000000000000000000000000000000000..8c87a9b3d00e298edf0395d13d276bdb950f5559 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/serialize.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +namespace torch::data::samplers { +/// Serializes a `Sampler` into an `OutputArchive`. +template +serialize::OutputArchive& operator<<( + serialize::OutputArchive& archive, + const Sampler& sampler) { + sampler.save(archive); + return archive; +} + +/// Deserializes a `Sampler` from an `InputArchive`. +template +serialize::InputArchive& operator>>( + serialize::InputArchive& archive, + Sampler& sampler) { + sampler.load(archive); + return archive; +} +} // namespace torch::data::samplers diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/stream.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/stream.h new file mode 100644 index 0000000000000000000000000000000000000000..c5eb8214cdf64cbed5f8f30687d6aac4c408de8b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/samplers/stream.h @@ -0,0 +1,57 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace torch::serialize { +class InputArchive; +class OutputArchive; +} // namespace torch::serialize + +namespace torch::data::samplers { + +/// A wrapper around a batch size value, which implements the +/// `CustomBatchRequest` interface. +struct TORCH_API BatchSize : public CustomBatchRequest { + explicit BatchSize(size_t size); + size_t size() const noexcept override; + operator size_t() const noexcept; + size_t size_; +}; + +/// A sampler for (potentially infinite) streams of data. +/// +/// The major feature of the `StreamSampler` is that it does not return +/// particular indices, but instead only the number of elements to fetch from +/// the dataset. The dataset has to decide how to produce those elements. +class TORCH_API StreamSampler : public Sampler { + public: + /// Constructs the `StreamSampler` with the number of individual examples that + /// should be fetched until the sampler is exhausted. + explicit StreamSampler(size_t epoch_size); + + /// Resets the internal state of the sampler. + void reset(std::optional new_size = std::nullopt) override; + + /// Returns a `BatchSize` object with the number of elements to fetch in the + /// next batch. This number is the minimum of the supplied `batch_size` and + /// the difference between the `epoch_size` and the current index. If the + /// `epoch_size` has been reached, returns an empty optional. + std::optional next(size_t batch_size) override; + + /// Serializes the `StreamSampler` to the `archive`. + void save(serialize::OutputArchive& archive) const override; + + /// Deserializes the `StreamSampler` from the `archive`. + void load(serialize::InputArchive& archive) override; + + private: + size_t examples_retrieved_so_far_ = 0; + size_t epoch_size_; +}; + +} // namespace torch::data::samplers diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms.h new file mode 100644 index 0000000000000000000000000000000000000000..e5d92062e62d52dd2dac3ab39f76385f9bf1522f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms.h @@ -0,0 +1,7 @@ +#pragma once + +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/base.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/base.h new file mode 100644 index 0000000000000000000000000000000000000000..b2ee9ed81f6b5f6b8a8681798d425ff1f404112a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/base.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +#include +#include + +namespace torch::data::transforms { + +/// A transformation of a batch to a new batch. +template +class BatchTransform { + public: + using InputBatchType = InputBatch; + using OutputBatchType = OutputBatch; + + virtual ~BatchTransform() = default; + + /// Applies the transformation to the given `input_batch`. + virtual OutputBatch apply_batch(InputBatch input_batch) = 0; +}; + +/// A transformation of individual input examples to individual output examples. +/// +/// Just like a `Dataset` is a `BatchDataset`, a `Transform` is a +/// `BatchTransform` that can operate on the level of individual examples rather +/// than entire batches. The batch-level transform is implemented (by default) +/// in terms of the example-level transform, though this can be customized. +template +class Transform + : public BatchTransform, std::vector> { + public: + using InputType = Input; + using OutputType = Output; + + /// Applies the transformation to the given `input`. + virtual OutputType apply(InputType input) = 0; + + /// Applies the `transformation` over the entire `input_batch`. + std::vector apply_batch(std::vector input_batch) override { + std::vector output_batch; + output_batch.reserve(input_batch.size()); + for (auto&& input : input_batch) { + output_batch.push_back(apply(std::move(input))); + } + return output_batch; + } +}; +} // namespace torch::data::transforms diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/collate.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/collate.h new file mode 100644 index 0000000000000000000000000000000000000000..8905fc7f7c9361349540ad45111035f70bc6f47f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/collate.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +#include + +namespace torch::data::transforms { + +/// A `Collation` is a transform that reduces a batch into a single value. +/// The result is a `BatchDataset` that has the type of the single value as its +/// `BatchType`. +template > +using Collation = BatchTransform; + +/// A `Collate` allows passing a custom function to reduce/collate a batch +/// into a single value. It's effectively the lambda version of `Collation`, +/// which you could subclass and override `operator()` to achieve the same. +/// +/// \rst +/// .. code-block:: cpp +/// using namespace torch::data; +/// +/// auto dataset = datasets::MNIST("path/to/mnist") +/// .map(transforms::Collate>([](std::vector> e) { +/// return std::move(e.front()); +/// })); +/// \endrst +template > +using Collate = BatchLambda; +} // namespace torch::data::transforms diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/lambda.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/lambda.h new file mode 100644 index 0000000000000000000000000000000000000000..c9cfa15431b26296365d7c8f452397109048a1a3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/lambda.h @@ -0,0 +1,52 @@ +#pragma once + +#include + +#include +#include +#include + +namespace torch::data::transforms { + +/// A `BatchTransform` that applies a user-provided functor to a batch. +template +class BatchLambda : public BatchTransform { + public: + using typename BatchTransform::InputBatchType; + using typename BatchTransform::OutputBatchType; + using FunctionType = std::function; + + /// Constructs the `BatchLambda` from the given `function` object. + explicit BatchLambda(FunctionType function) + : function_(std::move(function)) {} + + /// Applies the user-provided function object to the `input_batch`. + OutputBatchType apply_batch(InputBatchType input_batch) override { + return function_(std::move(input_batch)); + } + + private: + FunctionType function_; +}; + +// A `Transform` that applies a user-provided functor to individual examples. +template +class Lambda : public Transform { + public: + using typename Transform::InputType; + using typename Transform::OutputType; + using FunctionType = std::function; + + /// Constructs the `Lambda` from the given `function` object. + explicit Lambda(FunctionType function) : function_(std::move(function)) {} + + /// Applies the user-provided function object to the `input`. + OutputType apply(InputType input) override { + return function_(std::move(input)); + } + + private: + FunctionType function_; +}; + +} // namespace torch::data::transforms diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/stack.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/stack.h new file mode 100644 index 0000000000000000000000000000000000000000..26063db4ea8535a2cbab3b338a740bf708edfcc1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/stack.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::data::transforms { + +template > +struct Stack; + +/// A `Collation` for `Example` types that stacks all data +/// tensors into one tensor, and all target (label) tensors into one tensor. +template <> +struct Stack> : public Collation> { + Example<> apply_batch(std::vector> examples) override { + std::vector data, targets; + data.reserve(examples.size()); + targets.reserve(examples.size()); + for (auto& example : examples) { + data.push_back(std::move(example.data)); + targets.push_back(std::move(example.target)); + } + return {torch::stack(data), torch::stack(targets)}; + } +}; + +/// A `Collation` for `Example` types that stacks all data +/// tensors into one tensor. +template <> +struct Stack + : public Collation> { + TensorExample apply_batch(std::vector examples) override { + std::vector data; + data.reserve(examples.size()); + for (auto& example : examples) { + data.push_back(std::move(example.data)); + } + return torch::stack(data); + } +}; +} // namespace torch::data::transforms diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/tensor.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..7b6280bd96859023e3c420999fab8d1c48b7ff2e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/transforms/tensor.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::data::transforms { + +/// A `Transform` that is specialized for the typical `Example` +/// combination. It exposes a single `operator()` interface hook (for +/// subclasses), and calls this function on input `Example` objects. +template +class TensorTransform + : public Transform, Example> { + public: + using E = Example; + using typename Transform::InputType; + using typename Transform::OutputType; + + /// Transforms a single input tensor to an output tensor. + virtual Tensor operator()(Tensor input) = 0; + + /// Implementation of `Transform::apply` that calls `operator()`. + OutputType apply(InputType input) override { + input.data = (*this)(std::move(input.data)); + return input; + } +}; + +/// A `Lambda` specialized for the typical `Example` input type. +template +class TensorLambda : public TensorTransform { + public: + using FunctionType = std::function; + + /// Creates a `TensorLambda` from the given `function`. + explicit TensorLambda(FunctionType function) + : function_(std::move(function)) {} + + /// Applies the user-provided functor to the input tensor. + Tensor operator()(Tensor input) override { + return function_(std::move(input)); + } + + private: + FunctionType function_; +}; + +/// Normalizes input tensors by subtracting the supplied mean and dividing by +/// the given standard deviation. +template +struct Normalize : public TensorTransform { + /// Constructs a `Normalize` transform. The mean and standard deviation can be + /// anything that is broadcastable over the input tensors (like single + /// scalars). + Normalize(ArrayRef mean, ArrayRef stddev) + : mean(torch::tensor(mean, torch::kFloat32) + .unsqueeze(/*dim=*/1) + .unsqueeze(/*dim=*/2)), + stddev(torch::tensor(stddev, torch::kFloat32) + .unsqueeze(/*dim=*/1) + .unsqueeze(/*dim=*/2)) {} + + torch::Tensor operator()(Tensor input) override { + return input.sub(mean).div(stddev); + } + + torch::Tensor mean, stddev; +}; +} // namespace torch::data::transforms diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/worker_exception.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/worker_exception.h new file mode 100644 index 0000000000000000000000000000000000000000..afaf369e5537660cc9f65915851d783d9b930d27 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/data/worker_exception.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include + +namespace torch::data { + +/// An exception thrown when a DataLoader's worker thread throws an exception, +/// which is caught. A `WorkerException` stores an `exception_ptr` to the +/// original exception thrown in the worker thread. +struct WorkerException : public std::exception { + /// Constructs a `WorkerException` from an `exception_ptr`. + explicit WorkerException(std::exception_ptr original) + // NOLINTNEXTLINE(bugprone-throw-keyword-missing) + : original_exception(std::move(original)), + message("Caught exception in DataLoader worker thread.") { + try { + std::rethrow_exception(original_exception); + } catch (std::exception& e) { + message += " Original message: "; + message += e.what(); + } + } + + const char* what() const noexcept override { + return message.c_str(); + } + + /// The original exception thrown in the worker thread. + std::exception_ptr original_exception; + + /// This exception's message (not the original exception's message). + std::string message; +}; + +} // namespace torch::data diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/detail/TensorDataContainer.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/detail/TensorDataContainer.h new file mode 100644 index 0000000000000000000000000000000000000000..9485af1d297d27a0251da41661252fabf9837746 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/detail/TensorDataContainer.h @@ -0,0 +1,349 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +#include + +namespace torch::detail { + +enum class TensorDataContainerType { Scalar, InitList, Tensor }; + +struct TensorDataContainer; + +inline std::ostream& operator<<( + std::ostream& stream, + const TensorDataContainer& tensor_data_container); + +inline c10::ScalarType compute_desired_dtype(c10::ScalarType scalar_type) { + if (scalar_type == at::kInt || scalar_type == at::kLong) { + // C++ `torch::tensor` with an integer type or an `at::ArrayRef` / + // `std::vector` / (nested) braced-init-list of integer types always + // produces a tensor of dtype `at::kLong` (aka. int64_t), matching Python + // `torch.tensor` behavior. + return at::kLong; + } else if (scalar_type == at::kFloat || scalar_type == at::kDouble) { + // C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` / + // `std::vector` / (nested) braced-init-list of floating-point types always + // produces a tensor of dtype `torch::get_default_dtype()`, matching Python + // `torch.tensor` behavior. + return at::typeMetaToScalarType(at::get_default_dtype()); + } else { + return scalar_type; + } +} + +// We use `TensorDataContainer` to support converting the following data +// container types into the equivalent Tensor: +// +// 1. Arbitrarily nested braced-init-list (e.g. `{{1, 2}, {3, 4}}`). +// 2. `at::ArrayRef` of supported tensor data types. +// 3. `std::vector` of supported tensor data types. +// +// At any time, a `TensorDataContainer` object represents one of the following: +// +// 1. A scalar with value `scalar()` and type `scalar_type()`. +// 2. A Tensor represented in `std::initializer_list` form, +// with value `init_list()`, Tensor scalar type `scalar_type()`, and Tensor +// sizes `sizes()`. +// 3. A Tensor represented in `at::Tensor` form, with value `tensor()`, scalar +// type `scalar_type()`, +// and Tensor sizes `sizes()`. +// +// All the infrastructure here is mostly to support converting an arbitrarily +// nested braced-init-list to the equivalent Tensor successfully. Consider the +// following example: +// +// `torch::tensor({{1}, {2}})` +// +// this will call into the `torch::tensor` function: +// +// `at::Tensor tensor(detail::TensorDataContainer tensor_data_container, const +// at::TensorOptions& options = {})` +// +// the compiler will first try to convert `{{1}, {2}}` to `TensorDataContainer` +// type: +// +// `TensorDataContainer({{1}, {2}})` +// +// which matches to the +// `TensorDataContainer(std::initializer_list)` +// constructor, and in an attempt to convert `{1}` and `{2}` to +// `TensorDataContainer`, it calls the following: +// +// `TensorDataContainer({1})` (same call path happens for `{2}`, and we'll just +// focus on `{1}` here) +// +// At this point, theoretically there are two plausible ways for `{1}` to be +// matched to one of the constructors of `TensorDataContainer`: +// +// 1. It can be a list-initialization of a scalar value, thus matching +// `TensorDataContainer(int value)`. +// 2. It can be converted to `std::initializer_list`, thus +// matching +// `TensorDataContainer(std::initializer_list)`. +// +// How does the compiler decide which one to choose? According to +// `https://en.cppreference.com/w/cpp/language/list_initialization`, +// braced-init-list always prefers the constructor that takes +// `std::initializer_list`. Hence we happily move forward with constructor #2, +// and it calls the following: +// +// `TensorDataContainer(1)` +// +// Now it matches `TensorDataContainer(int value)`, which stores `1` as a scalar +// value. All is good. +struct TensorDataContainer { + // NOTE: For tensors with zero-size dimensions (e.g. `torch::tensor({{}, + // {}})`), the innermost empty braced-init-list `{}` matches the default + // constructor of the innermost `TensorDataContainer`. + TensorDataContainer() + : sizes_({0}), + // NOTE: In Python, the dtype of tensors with zero-size dimensions (e.g. + // `torch.tensor([[], []])`) depends on the value of + // `torch.get_default_dtype()`, and we should do the same for the C++ + // equivalent. + scalar_type_(at::typeMetaToScalarType(at::get_default_dtype())), + type_(TensorDataContainerType::InitList) {} +#define TENSOR(T, S) \ + TensorDataContainer(T value) \ + : scalar_type_(at::k##S), \ + type_(TensorDataContainerType::Scalar), \ + scalar_(value) {} + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) + AT_FORALL_COMPLEX_TYPES(TENSOR) +#undef TENSOR + TensorDataContainer(std::initializer_list init_list) + : scalar_type_(init_list.begin()->scalar_type()), + type_(TensorDataContainerType::InitList), + init_list_(init_list) { + const TensorDataContainer& first_elem = *(init_list.begin()); + for (const auto& elem : init_list) { + TORCH_CHECK( + elem.sizes() == first_elem.sizes(), + "Expected all sub-lists to have sizes: ", + first_elem.sizes(), + " (e.g. ", + first_elem, + "), ", + "but got sub-list ", + elem, + " with sizes: ", + elem.sizes()); + TORCH_CHECK( + elem.scalar_type() == first_elem.scalar_type(), + "Expected all elements of the tensor to have the same scalar type: ", + first_elem.scalar_type(), + ", but got element of scalar type: ", + elem.scalar_type()); + } + sizes_.reserve(first_elem.sizes().size() + 1); + sizes_.push_back(static_cast(init_list.size())); + sizes_.insert( + sizes_.end(), first_elem.sizes().begin(), first_elem.sizes().end()); + } + +#define TENSOR(T, S) \ + TensorDataContainer(at::ArrayRef values) \ + : sizes_({(int64_t)values.size()}), \ + scalar_type_(at::k##S), \ + type_(TensorDataContainerType::Tensor) { \ + at::AutoDispatchBelowAutograd mode; \ + if (scalar_type_ == at::kBool) { \ + tensor_ = at::tensor(values, at::TensorOptions().device(at::kCPU)); \ + } else { \ + tensor_ = at::tensor(values, at::dtype(scalar_type_).device(at::kCPU)); \ + } \ + } + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) + AT_FORALL_COMPLEX_TYPES(TENSOR) +#undef TENSOR + + // NOTE: We need to handle `std::vector` explicitly instead of relying on an + // implicit conversion to `at::ArrayRef`, otherwise the following error can be + // thrown when calling `torch::tensor(std::vector({1, 2}))`: + // ``` + // error: no matching function for call to 'tensor(const std::vector&)' + // no known conversion for argument 1 from 'const std::vector' to + // 'torch::detail::TensorDataContainer' + // ``` + // + // NOTE: `torch::tensor(std::vector)` is not supported for now, because + // ArrayRef cannot be constructed from a std::vector bitfield. +#define TENSOR(T, S) \ + TensorDataContainer(const std::vector& values) \ + : TensorDataContainer(at::ArrayRef(values)) {} + AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TENSOR) + AT_FORALL_COMPLEX_TYPES(TENSOR) +#undef TENSOR + + bool is_scalar() const { + return type_ == TensorDataContainerType::Scalar; + } + + const c10::Scalar& scalar() const { + TORCH_CHECK( + is_scalar(), + "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`"); + return scalar_; + } + + bool is_init_list() const { + return type_ == TensorDataContainerType::InitList; + } + + const std::initializer_list& init_list() const { + TORCH_CHECK( + is_init_list(), + "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`"); + return init_list_; + } + + bool is_tensor() const { + return type_ == TensorDataContainerType::Tensor; + } + + const at::Tensor& tensor() const { + TORCH_CHECK( + is_tensor(), + "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`"); + return tensor_; + } + + const std::vector& sizes() const { + return sizes_; + } + + const c10::ScalarType& scalar_type() const { + return scalar_type_; + } + + at::Tensor convert_to_tensor(at::TensorOptions options) const { + if (!options.has_dtype()) { + options = options.dtype(compute_desired_dtype(scalar_type_)); + } + + if (is_scalar()) { + at::AutoDispatchBelowAutograd mode; + return at::scalar_tensor(scalar_, options); + } else if (is_init_list()) { + // NOTE: Here we explicitly choose to initialize the tensor on CPU first, + // fill each element of the tensor, and then move the tensor to the + // desired device. For CUDA device, this approach only involves 1 CUDA + // kernel launch, and is much faster than initializing the tensor on CUDA + // first and then filling each element of it (which involves `N` CUDA + // kernel launches where `N` is the number of the elements in the tensor). + at::Tensor tensor = ([&]() { + at::AutoDispatchBelowAutograd mode; + return at::empty(sizes_, options.device(at::kCPU)); + })(); + fill_tensor(tensor); + return tensor.to(options.device()); + } else if (is_tensor()) { + auto output = tensor_.to(options); + TORCH_CHECK( + !tensor_.is_complex() || output.is_complex(), + "can not do torch::tensor(complex, dtype=non-complex) because complex can not be casted to real number without loss of information"); + return output; + } else { + TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type"); + } + } + + void pretty_print_recursive(std::ostream& stream) const { + if (is_scalar()) { + AT_DISPATCH_ALL_TYPES_AND3( + at::kBool, + at::kHalf, + at::kBFloat16, + scalar_type_, + "TensorDataContainer_pretty_print_scalar", + [&] { stream << scalar_.to(); }); + } else if (is_init_list()) { + stream << "{"; + for (const TensorDataContainer* it = init_list_.begin(); + it != init_list_.end(); + it++) { + stream << *it; + if (std::next(it) != init_list_.end()) + stream << ", "; + } + stream << "}"; + } else if (is_tensor()) { + stream << "{"; + for (const auto i : c10::irange(tensor_.sizes()[0])) { + AT_DISPATCH_ALL_TYPES_AND3( + at::kBool, + at::kHalf, + at::kBFloat16, + scalar_type_, + "TensorDataContainer_pretty_print_tensor_item", + [&] { stream << tensor_[i].item(); }); + if (i != tensor_.sizes()[0] - 1) + stream << ", "; + } + stream << "}"; + } else { + TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type"); + } + } + + private: + void fill_tensor(at::Tensor& tensor) const { + if (is_scalar()) { + TORCH_INTERNAL_ASSERT( + tensor.dim() == 0, + "Expected a 0-dim Tensor, but got Tensor with dimensions: ", + tensor.dim()); + at::NoGradGuard guard; + tensor.fill_(scalar_); + } else if (is_init_list()) { + TORCH_INTERNAL_ASSERT( + tensor.sizes()[0] == (int64_t)init_list_.size(), + "Expected a Tensor with size ", + init_list_.size(), + " in its first dimension, but got Tensor with size ", + tensor.sizes()[0], + " in its first dimension"); + int64_t index = 0; + for (const auto& elem : init_list_) { + at::Tensor slice = tensor[index]; + elem.fill_tensor(slice); + index++; + } + } else if (is_tensor()) { + TORCH_INTERNAL_ASSERT( + false, + "TensorDataContainer is already a Tensor type, `fill_tensor` should not be called"); + } else { + TORCH_INTERNAL_ASSERT(false, "Invalid TensorDataContainer type"); + } + } + + std::vector sizes_; + c10::ScalarType scalar_type_; + TensorDataContainerType type_; + c10::Scalar scalar_; + std::initializer_list init_list_; + at::Tensor tensor_; +}; + +inline std::ostream& operator<<( + std::ostream& stream, + const TensorDataContainer& tensor_data_container) { + tensor_data_container.pretty_print_recursive(stream); + return stream; +} + +} // namespace torch::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/detail/static.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/detail/static.h new file mode 100644 index 0000000000000000000000000000000000000000..3a8a58b598737bce8e5f831cfea902fb266e82c4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/detail/static.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch::nn { +class Module; +} // namespace torch::nn + +namespace torch::detail { +/// Detects if a type T has a forward() method. +template +struct has_forward { + // Declare two types with differing size. + using yes = int8_t; + using no = int16_t; + + // Here we declare two functions. The first is only enabled if `&U::forward` + // is well-formed and returns the `yes` type. In C++, the ellipsis parameter + // type (`...`) always puts the function at the bottom of overload resolution. + // This is specified in the standard as: 1) A standard conversion sequence is + // always better than a user-defined conversion sequence or an ellipsis + // conversion sequence. 2) A user-defined conversion sequence is always better + // than an ellipsis conversion sequence This means that if the first overload + // is viable, it will be preferred over the second as long as we pass any + // convertible type. The type of `&U::forward` is a pointer type, so we can + // pass e.g. 0. + template + static yes test(decltype(&U::forward)); + template + static no test(...); + + // Finally we test statically whether the size of the type returned by the + // selected overload is the size of the `yes` type. + static constexpr bool value = (sizeof(test(nullptr)) == sizeof(yes)); +}; + +template +constexpr bool check_not_lvalue_references() { + return (!std::is_lvalue_reference_v || + std::is_const_v>) && + check_not_lvalue_references(); +} + +template <> +inline constexpr bool check_not_lvalue_references() { + return true; +} + +/// A type trait whose `value` member is true if `M` derives from `Module`. +template +using is_module = std::is_base_of>; + +template +using enable_if_module_t = std::enable_if_t::value, T>; +} // namespace torch::detail diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/enum.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/enum.h new file mode 100644 index 0000000000000000000000000000000000000000..195b776b672d857cdceb72fa2c92eedfcd1fcecb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/enum.h @@ -0,0 +1,210 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#define TORCH_ENUM_DECLARE(name) \ + namespace torch { \ + namespace enumtype { \ + /* \ + NOTE: We need to provide the default constructor for each struct, \ + otherwise Clang 3.8 would complain: \ + ``` \ + error: default initialization of an object of const type 'const \ + enumtype::Enum1' without a user-provided default constructor \ + ``` \ + */ \ + struct k##name { \ + k##name() {} \ + }; \ + } \ + TORCH_API extern const enumtype::k##name k##name; \ + } + +#define TORCH_ENUM_DEFINE(name) \ + namespace torch { \ + const enumtype::k##name k##name; \ + } + +#define TORCH_ENUM_PRETTY_PRINT(name) \ + std::string operator()(const enumtype::k##name& v [[maybe_unused]]) const { \ + std::string k("k"); \ + return k + #name; \ + } + +// NOTE: Backstory on why we need the following two macros: +// +// Consider the following options class: +// +// ``` +// struct TORCH_API SomeOptions { +// typedef std::variant +// reduction_t; SomeOptions(reduction_t reduction = torch::kMean) : +// reduction_(reduction) {} +// +// TORCH_ARG(reduction_t, reduction); +// }; +// ``` +// +// and the functional that uses it: +// +// ``` +// Tensor some_functional( +// const Tensor& input, +// SomeOptions options = {}) { +// ... +// } +// ``` +// +// Normally, we would expect this to work: +// +// `F::some_functional(input, torch::kNone)` +// +// However, it throws the following error instead: +// +// ``` +// error: could not convert `torch::kNone` from `const torch::enumtype::kNone` +// to `torch::nn::SomeOptions` +// ``` +// +// To get around this problem, we explicitly provide the following constructors +// for `SomeOptions`: +// +// ``` +// SomeOptions(torch::enumtype::kNone reduction) : reduction_(torch::kNone) {} +// SomeOptions(torch::enumtype::kMean reduction) : reduction_(torch::kMean) {} +// SomeOptions(torch::enumtype::kSum reduction) : reduction_(torch::kSum) {} +// ``` +// +// so that the conversion from `torch::kNone` to `SomeOptions` would work. +// +// Note that we also provide the default constructor `SomeOptions() {}`, so that +// `SomeOptions options = {}` can work. +#define TORCH_OPTIONS_CTOR_VARIANT_ARG3( \ + OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3) \ + OPTIONS_NAME() = default; \ + OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME) : ARG_NAME##_(torch::TYPE1) {} \ + OPTIONS_NAME(torch::enumtype::TYPE2 ARG_NAME) : ARG_NAME##_(torch::TYPE2) {} \ + OPTIONS_NAME(torch::enumtype::TYPE3 ARG_NAME) : ARG_NAME##_(torch::TYPE3) {} + +#define TORCH_OPTIONS_CTOR_VARIANT_ARG4( \ + OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3, TYPE4) \ + OPTIONS_NAME() = default; \ + OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME) : ARG_NAME##_(torch::TYPE1) {} \ + OPTIONS_NAME(torch::enumtype::TYPE2 ARG_NAME) : ARG_NAME##_(torch::TYPE2) {} \ + OPTIONS_NAME(torch::enumtype::TYPE3 ARG_NAME) : ARG_NAME##_(torch::TYPE3) {} \ + OPTIONS_NAME(torch::enumtype::TYPE4 ARG_NAME) : ARG_NAME##_(torch::TYPE4) {} + +TORCH_ENUM_DECLARE(Linear) +TORCH_ENUM_DECLARE(Conv1D) +TORCH_ENUM_DECLARE(Conv2D) +TORCH_ENUM_DECLARE(Conv3D) +TORCH_ENUM_DECLARE(ConvTranspose1D) +TORCH_ENUM_DECLARE(ConvTranspose2D) +TORCH_ENUM_DECLARE(ConvTranspose3D) +TORCH_ENUM_DECLARE(Sigmoid) +TORCH_ENUM_DECLARE(Tanh) +TORCH_ENUM_DECLARE(ReLU) +TORCH_ENUM_DECLARE(GELU) +TORCH_ENUM_DECLARE(SiLU) +TORCH_ENUM_DECLARE(Mish) +TORCH_ENUM_DECLARE(LeakyReLU) +TORCH_ENUM_DECLARE(FanIn) +TORCH_ENUM_DECLARE(FanOut) +TORCH_ENUM_DECLARE(Constant) +TORCH_ENUM_DECLARE(Reflect) +TORCH_ENUM_DECLARE(Replicate) +TORCH_ENUM_DECLARE(Circular) +TORCH_ENUM_DECLARE(Nearest) +TORCH_ENUM_DECLARE(Bilinear) +TORCH_ENUM_DECLARE(Bicubic) +TORCH_ENUM_DECLARE(Trilinear) +TORCH_ENUM_DECLARE(Area) +TORCH_ENUM_DECLARE(NearestExact) +TORCH_ENUM_DECLARE(Sum) +TORCH_ENUM_DECLARE(Mean) +TORCH_ENUM_DECLARE(Max) +TORCH_ENUM_DECLARE(None) +TORCH_ENUM_DECLARE(BatchMean) +TORCH_ENUM_DECLARE(Zeros) +TORCH_ENUM_DECLARE(Border) +TORCH_ENUM_DECLARE(Reflection) +TORCH_ENUM_DECLARE(RNN_TANH) +TORCH_ENUM_DECLARE(RNN_RELU) +TORCH_ENUM_DECLARE(LSTM) +TORCH_ENUM_DECLARE(GRU) +TORCH_ENUM_DECLARE(Valid) +TORCH_ENUM_DECLARE(Same) + +namespace torch::enumtype { + +struct _compute_enum_name { + TORCH_ENUM_PRETTY_PRINT(Linear) + TORCH_ENUM_PRETTY_PRINT(Conv1D) + TORCH_ENUM_PRETTY_PRINT(Conv2D) + TORCH_ENUM_PRETTY_PRINT(Conv3D) + TORCH_ENUM_PRETTY_PRINT(ConvTranspose1D) + TORCH_ENUM_PRETTY_PRINT(ConvTranspose2D) + TORCH_ENUM_PRETTY_PRINT(ConvTranspose3D) + TORCH_ENUM_PRETTY_PRINT(Sigmoid) + TORCH_ENUM_PRETTY_PRINT(Tanh) + TORCH_ENUM_PRETTY_PRINT(ReLU) + TORCH_ENUM_PRETTY_PRINT(GELU) + TORCH_ENUM_PRETTY_PRINT(SiLU) + TORCH_ENUM_PRETTY_PRINT(Mish) + TORCH_ENUM_PRETTY_PRINT(LeakyReLU) + TORCH_ENUM_PRETTY_PRINT(FanIn) + TORCH_ENUM_PRETTY_PRINT(FanOut) + TORCH_ENUM_PRETTY_PRINT(Constant) + TORCH_ENUM_PRETTY_PRINT(Reflect) + TORCH_ENUM_PRETTY_PRINT(Replicate) + TORCH_ENUM_PRETTY_PRINT(Circular) + TORCH_ENUM_PRETTY_PRINT(Nearest) + TORCH_ENUM_PRETTY_PRINT(Bilinear) + TORCH_ENUM_PRETTY_PRINT(Bicubic) + TORCH_ENUM_PRETTY_PRINT(Trilinear) + TORCH_ENUM_PRETTY_PRINT(Area) + TORCH_ENUM_PRETTY_PRINT(NearestExact) + TORCH_ENUM_PRETTY_PRINT(Sum) + TORCH_ENUM_PRETTY_PRINT(Mean) + TORCH_ENUM_PRETTY_PRINT(Max) + TORCH_ENUM_PRETTY_PRINT(None) + TORCH_ENUM_PRETTY_PRINT(BatchMean) + TORCH_ENUM_PRETTY_PRINT(Zeros) + TORCH_ENUM_PRETTY_PRINT(Border) + TORCH_ENUM_PRETTY_PRINT(Reflection) + TORCH_ENUM_PRETTY_PRINT(RNN_TANH) + TORCH_ENUM_PRETTY_PRINT(RNN_RELU) + TORCH_ENUM_PRETTY_PRINT(LSTM) + TORCH_ENUM_PRETTY_PRINT(GRU) + TORCH_ENUM_PRETTY_PRINT(Valid) + TORCH_ENUM_PRETTY_PRINT(Same) +}; + +template +std::string get_enum_name(V variant_enum) { + return std::visit(enumtype::_compute_enum_name{}, variant_enum); +} + +template +at::Reduction::Reduction reduction_get_enum(V variant_enum) { + if (std::holds_alternative(variant_enum)) { + return at::Reduction::None; + } else if (std::holds_alternative(variant_enum)) { + return at::Reduction::Mean; + } else if (std::holds_alternative(variant_enum)) { + return at::Reduction::Sum; + } else { + TORCH_CHECK( + false, + get_enum_name(variant_enum), + " is not a valid value for reduction"); + return at::Reduction::END; + } +} + +} // namespace torch::enumtype diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/expanding_array.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/expanding_array.h new file mode 100644 index 0000000000000000000000000000000000000000..e7c834626dd7ff34d6cab106aeb1c6e30f659d34 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/expanding_array.h @@ -0,0 +1,182 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace torch { + +/// A utility class that accepts either a container of `D`-many values, or a +/// single value, which is internally repeated `D` times. This is useful to +/// represent parameters that are multidimensional, but often equally sized in +/// all dimensions. For example, the kernel size of a 2D convolution has an `x` +/// and `y` length, but `x` and `y` are often equal. In such a case you could +/// just pass `3` to an `ExpandingArray<2>` and it would "expand" to `{3, 3}`. +template +class ExpandingArray { + public: + /// Constructs an `ExpandingArray` from an `initializer_list`. The extent of + /// the length is checked against the `ExpandingArray`'s extent parameter `D` + /// at runtime. + /*implicit*/ ExpandingArray(std::initializer_list list) + : ExpandingArray(c10::ArrayRef(list)) {} + + /// Constructs an `ExpandingArray` from an `std::vector`. The extent of + /// the length is checked against the `ExpandingArray`'s extent parameter `D` + /// at runtime. + /*implicit*/ ExpandingArray(std::vector vec) + : ExpandingArray(c10::ArrayRef(vec)) {} + + /// Constructs an `ExpandingArray` from an `c10::ArrayRef`. The extent of + /// the length is checked against the `ExpandingArray`'s extent parameter `D` + /// at runtime. + /*implicit*/ ExpandingArray(c10::ArrayRef values) { + // clang-format off + TORCH_CHECK( + values.size() == D, + "Expected ", D, " values, but instead got ", values.size()); + // clang-format on + std::copy(values.begin(), values.end(), values_.begin()); + } + + /// Constructs an `ExpandingArray` from a single value, which is repeated `D` + /// times (where `D` is the extent parameter of the `ExpandingArray`). + /*implicit*/ ExpandingArray(T single_size) { + values_.fill(single_size); + } + + /// Constructs an `ExpandingArray` from a correctly sized `std::array`. + /*implicit*/ ExpandingArray(const std::array& values) + : values_(values) {} + + /// Accesses the underlying `std::array`. + std::array& operator*() { + return values_; + } + + /// Accesses the underlying `std::array`. + const std::array& operator*() const { + return values_; + } + + /// Accesses the underlying `std::array`. + std::array* operator->() { + return &values_; + } + + /// Accesses the underlying `std::array`. + const std::array* operator->() const { + return &values_; + } + + /// Returns an `ArrayRef` to the underlying `std::array`. + operator c10::ArrayRef() const { + return values_; + } + + /// Returns the extent of the `ExpandingArray`. + size_t size() const noexcept { + return D; + } + + protected: + /// The backing array. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::array values_; +}; + +template +std::ostream& operator<<( + std::ostream& stream, + const ExpandingArray& expanding_array) { + if (expanding_array.size() == 1) { + return stream << expanding_array->at(0); + } + return stream << static_cast>(expanding_array); +} + +/// A utility class that accepts either a container of `D`-many +/// `std::optional` values, or a single `std::optional` value, which is +/// internally repeated `D` times. It has the additional ability to accept +/// containers of the underlying type `T` and convert them to a container of +/// `std::optional`. +template +class ExpandingArrayWithOptionalElem + : public ExpandingArray> { + public: + using ExpandingArray>::ExpandingArray; + + /// Constructs an `ExpandingArrayWithOptionalElem` from an `initializer_list` + /// of the underlying type `T`. The extent of the length is checked against + /// the `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime. + /*implicit*/ ExpandingArrayWithOptionalElem(std::initializer_list list) + : ExpandingArrayWithOptionalElem(c10::ArrayRef(list)) {} + + /// Constructs an `ExpandingArrayWithOptionalElem` from an `std::vector` of + /// the underlying type `T`. The extent of the length is checked against the + /// `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime. + /*implicit*/ ExpandingArrayWithOptionalElem(std::vector vec) + : ExpandingArrayWithOptionalElem(c10::ArrayRef(vec)) {} + + /// Constructs an `ExpandingArrayWithOptionalElem` from an `c10::ArrayRef` of + /// the underlying type `T`. The extent of the length is checked against the + /// `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime. + /*implicit*/ ExpandingArrayWithOptionalElem(c10::ArrayRef values) + : ExpandingArray>(0) { + // clang-format off + TORCH_CHECK( + values.size() == D, + "Expected ", D, " values, but instead got ", values.size()); + // clang-format on + for (const auto i : c10::irange(this->values_.size())) { + this->values_[i] = values[i]; + } + } + + /// Constructs an `ExpandingArrayWithOptionalElem` from a single value of the + /// underlying type `T`, which is repeated `D` times (where `D` is the extent + /// parameter of the `ExpandingArrayWithOptionalElem`). + /*implicit*/ ExpandingArrayWithOptionalElem(T single_size) + : ExpandingArray>(0) { + for (const auto i : c10::irange(this->values_.size())) { + this->values_[i] = single_size; + } + } + + /// Constructs an `ExpandingArrayWithOptionalElem` from a correctly sized + /// `std::array` of the underlying type `T`. + /*implicit*/ ExpandingArrayWithOptionalElem(const std::array& values) + : ExpandingArray>(0) { + for (const auto i : c10::irange(this->values_.size())) { + this->values_[i] = values[i]; + } + } +}; + +template +std::ostream& operator<<( + std::ostream& stream, + const ExpandingArrayWithOptionalElem& expanding_array_with_opt_elem) { + if (expanding_array_with_opt_elem.size() == 1) { + const auto& elem = expanding_array_with_opt_elem->at(0); + stream << (elem.has_value() ? c10::str(elem.value()) : "None"); + } else { + std::vector str_array; + for (const auto& elem : *expanding_array_with_opt_elem) { + str_array.emplace_back( + elem.has_value() ? c10::str(elem.value()) : "None"); + } + stream << c10::ArrayRef(str_array); + } + return stream; +} + +} // namespace torch diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/fft.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/fft.h new file mode 100644 index 0000000000000000000000000000000000000000..c1a355abd1bb7e5bccd6bc3a66f3f731ae755e24 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/fft.h @@ -0,0 +1,390 @@ +#pragma once + +#include +#include + +#include + +namespace torch::fft { + +/// Computes the 1 dimensional fast Fourier transform over a given dimension. +/// See https://pytorch.org/docs/main/fft.html#torch.fft.fft. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kComplexDouble); +/// torch::fft::fft(t); +/// ``` +inline Tensor fft( + const Tensor& self, + std::optional n = std::nullopt, + int64_t dim = -1, + std::optional norm = std::nullopt) { + return torch::fft_fft_symint(self, std::move(n), dim, norm); +} + +/// Computes the 1 dimensional inverse Fourier transform over a given dimension. +/// See https://pytorch.org/docs/main/fft.html#torch.fft.ifft. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kComplexDouble); +/// torch::fft::ifft(t); +/// ``` +inline Tensor ifft( + const Tensor& self, + std::optional n = std::nullopt, + int64_t dim = -1, + std::optional norm = std::nullopt) { + return torch::fft_ifft_symint(self, std::move(n), dim, norm); +} + +/// Computes the 2-dimensional fast Fourier transform over the given dimensions. +/// See https://pytorch.org/docs/main/fft.html#torch.fft.fft2. +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 128}, dtype=kComplexDouble); +/// torch::fft::fft2(t); +/// ``` +inline Tensor fft2( + const Tensor& self, + OptionalIntArrayRef s = std::nullopt, + IntArrayRef dim = {-2, -1}, + std::optional norm = std::nullopt) { + return torch::fft_fft2(self, s, dim, norm); +} + +/// Computes the inverse of torch.fft.fft2 +/// See https://pytorch.org/docs/main/fft.html#torch.fft.ifft2. +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 128}, dtype=kComplexDouble); +/// torch::fft::ifft2(t); +/// ``` +inline Tensor ifft2( + const Tensor& self, + at::OptionalIntArrayRef s = std::nullopt, + IntArrayRef dim = {-2, -1}, + std::optional norm = std::nullopt) { + return torch::fft_ifft2(self, s, dim, norm); +} + +/// Computes the N dimensional fast Fourier transform over given dimensions. +/// See https://pytorch.org/docs/main/fft.html#torch.fft.fftn. +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 128}, dtype=kComplexDouble); +/// torch::fft::fftn(t); +/// ``` +inline Tensor fftn( + const Tensor& self, + at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef dim = std::nullopt, + std::optional norm = std::nullopt) { + return torch::fft_fftn(self, s, dim, norm); +} + +/// Computes the N dimensional fast Fourier transform over given dimensions. +/// See https://pytorch.org/docs/main/fft.html#torch.fft.ifftn. +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 128}, dtype=kComplexDouble); +/// torch::fft::ifftn(t); +/// ``` +inline Tensor ifftn( + const Tensor& self, + at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef dim = std::nullopt, + std::optional norm = std::nullopt) { + return torch::fft_ifftn(self, s, dim, norm); +} + +/// Computes the 1 dimensional FFT of real input with onesided Hermitian output. +/// See https://pytorch.org/docs/main/fft.html#torch.fft.rfft. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128); +/// auto T = torch::fft::rfft(t); +/// assert(T.is_complex() && T.numel() == 128 / 2 + 1); +/// ``` +inline Tensor rfft( + const Tensor& self, + std::optional n = std::nullopt, + int64_t dim = -1, + std::optional norm = std::nullopt) { + return torch::fft_rfft_symint(self, std::move(n), dim, norm); +} + +/// Computes the inverse of torch.fft.rfft +/// +/// The input is a onesided Hermitian Fourier domain signal, with real-valued +/// output. See https://pytorch.org/docs/main/fft.html#torch.fft.irfft +/// +/// Example: +/// ``` +/// auto T = torch::randn(128 / 2 + 1, torch::kComplexDouble); +/// auto t = torch::fft::irfft(t, /*n=*/128); +/// assert(t.is_floating_point() && T.numel() == 128); +/// ``` +inline Tensor irfft( + const Tensor& self, + std::optional n = std::nullopt, + int64_t dim = -1, + std::optional norm = std::nullopt) { + return torch::fft_irfft_symint(self, std::move(n), dim, norm); +} + +/// Computes the 2-dimensional FFT of real input. Returns a onesided Hermitian +/// output. See https://pytorch.org/docs/main/fft.html#torch.fft.rfft2 +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 128}, dtype=kDouble); +/// torch::fft::rfft2(t); +/// ``` +inline Tensor rfft2( + const Tensor& self, + at::OptionalIntArrayRef s = std::nullopt, + IntArrayRef dim = {-2, -1}, + std::optional norm = std::nullopt) { + return torch::fft_rfft2(self, s, dim, norm); +} + +/// Computes the inverse of torch.fft.rfft2. +/// See https://pytorch.org/docs/main/fft.html#torch.fft.irfft2. +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 128}, dtype=kComplexDouble); +/// torch::fft::irfft2(t); +/// ``` +inline Tensor irfft2( + const Tensor& self, + at::OptionalIntArrayRef s = std::nullopt, + IntArrayRef dim = {-2, -1}, + std::optional norm = std::nullopt) { + return torch::fft_irfft2(self, s, dim, norm); +} + +/// Computes the N dimensional FFT of real input with onesided Hermitian output. +/// See https://pytorch.org/docs/main/fft.html#torch.fft.rfftn +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 128}, dtype=kDouble); +/// torch::fft::rfftn(t); +/// ``` +inline Tensor rfftn( + const Tensor& self, + at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef dim = std::nullopt, + std::optional norm = std::nullopt) { + return torch::fft_rfftn(self, s, dim, norm); +} + +/// Computes the inverse of torch.fft.rfftn. +/// See https://pytorch.org/docs/main/fft.html#torch.fft.irfftn. +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 128}, dtype=kComplexDouble); +/// torch::fft::irfftn(t); +/// ``` +inline Tensor irfftn( + const Tensor& self, + at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef dim = std::nullopt, + std::optional norm = std::nullopt) { + return torch::fft_irfftn(self, s, dim, norm); +} + +/// Computes the 1 dimensional FFT of a onesided Hermitian signal +/// +/// The input represents a Hermitian symmetric time domain signal. The returned +/// Fourier domain representation of such a signal is a real-valued. See +/// https://pytorch.org/docs/main/fft.html#torch.fft.hfft +/// +/// Example: +/// ``` +/// auto t = torch::randn(128 / 2 + 1, torch::kComplexDouble); +/// auto T = torch::fft::hfft(t, /*n=*/128); +/// assert(T.is_floating_point() && T.numel() == 128); +/// ``` +inline Tensor hfft( + const Tensor& self, + std::optional n = std::nullopt, + int64_t dim = -1, + std::optional norm = std::nullopt) { + return torch::fft_hfft_symint(self, std::move(n), dim, norm); +} + +/// Computes the inverse FFT of a real-valued Fourier domain signal. +/// +/// The output is a onesided representation of the Hermitian symmetric time +/// domain signal. See https://pytorch.org/docs/main/fft.html#torch.fft.ihfft. +/// +/// Example: +/// ``` +/// auto T = torch::randn(128, torch::kDouble); +/// auto t = torch::fft::ihfft(T); +/// assert(t.is_complex() && T.numel() == 128 / 2 + 1); +/// ``` +inline Tensor ihfft( + const Tensor& self, + std::optional n = std::nullopt, + int64_t dim = -1, + std::optional norm = std::nullopt) { + return torch::fft_ihfft_symint(self, std::move(n), dim, norm); +} + +/// Computes the 2-dimensional FFT of a Hermitian symmetric input signal. +/// +/// The input is a onesided representation of the Hermitian symmetric time +/// domain signal. See https://pytorch.org/docs/main/fft.html#torch.fft.hfft2. +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 65}, torch::kComplexDouble); +/// auto T = torch::fft::hfft2(t, /*s=*/{128, 128}); +/// assert(T.is_floating_point() && T.numel() == 128 * 128); +/// ``` +inline Tensor hfft2( + const Tensor& self, + at::OptionalIntArrayRef s = std::nullopt, + IntArrayRef dim = {-2, -1}, + std::optional norm = std::nullopt) { + return torch::fft_hfft2(self, s, dim, norm); +} + +/// Computes the 2-dimensional IFFT of a real input signal. +/// +/// The output is a onesided representation of the Hermitian symmetric time +/// domain signal. See +/// https://pytorch.org/docs/main/fft.html#torch.fft.ihfft2. +/// +/// Example: +/// ``` +/// auto T = torch::randn({128, 128}, torch::kDouble); +/// auto t = torch::fft::hfft2(T); +/// assert(t.is_complex() && t.size(1) == 65); +/// ``` +inline Tensor ihfft2( + const Tensor& self, + at::OptionalIntArrayRef s = std::nullopt, + IntArrayRef dim = {-2, -1}, + std::optional norm = std::nullopt) { + return torch::fft_ihfft2(self, s, dim, norm); +} + +/// Computes the N-dimensional FFT of a Hermitian symmetric input signal. +/// +/// The input is a onesided representation of the Hermitian symmetric time +/// domain signal. See https://pytorch.org/docs/main/fft.html#torch.fft.hfftn. +/// +/// Example: +/// ``` +/// auto t = torch::randn({128, 65}, torch::kComplexDouble); +/// auto T = torch::fft::hfftn(t, /*s=*/{128, 128}); +/// assert(T.is_floating_point() && T.numel() == 128 * 128); +/// ``` +inline Tensor hfftn( + const Tensor& self, + at::OptionalIntArrayRef s = std::nullopt, + IntArrayRef dim = {-2, -1}, + std::optional norm = std::nullopt) { + return torch::fft_hfftn(self, s, dim, norm); +} + +/// Computes the N-dimensional IFFT of a real input signal. +/// +/// The output is a onesided representation of the Hermitian symmetric time +/// domain signal. See +/// https://pytorch.org/docs/main/fft.html#torch.fft.ihfftn. +/// +/// Example: +/// ``` +/// auto T = torch::randn({128, 128}, torch::kDouble); +/// auto t = torch::fft::hfft2(T); +/// assert(t.is_complex() && t.size(1) == 65); +/// ``` +inline Tensor ihfftn( + const Tensor& self, + at::OptionalIntArrayRef s = std::nullopt, + IntArrayRef dim = {-2, -1}, + std::optional norm = std::nullopt) { + return torch::fft_ihfftn(self, s, dim, norm); +} + +/// Computes the discrete Fourier Transform sample frequencies for a signal of +/// size n. +/// +/// See https://pytorch.org/docs/main/fft.html#torch.fft.fftfreq +/// +/// Example: +/// ``` +/// auto frequencies = torch::fft::fftfreq(128, torch::kDouble); +/// ``` +inline Tensor fftfreq(int64_t n, double d, const TensorOptions& options = {}) { + return torch::fft_fftfreq(n, d, options); +} + +inline Tensor fftfreq(int64_t n, const TensorOptions& options = {}) { + return torch::fft_fftfreq(n, /*d=*/1.0, options); +} + +/// Computes the sample frequencies for torch.fft.rfft with a signal of size n. +/// +/// Like torch.fft.rfft, only the positive frequencies are included. +/// See https://pytorch.org/docs/main/fft.html#torch.fft.rfftfreq +/// +/// Example: +/// ``` +/// auto frequencies = torch::fft::rfftfreq(128, torch::kDouble); +/// ``` +inline Tensor rfftfreq(int64_t n, double d, const TensorOptions& options) { + return torch::fft_rfftfreq(n, d, options); +} + +inline Tensor rfftfreq(int64_t n, const TensorOptions& options) { + return torch::fft_rfftfreq(n, /*d=*/1.0, options); +} + +/// Reorders n-dimensional FFT output to have negative frequency terms first, by +/// a torch.roll operation. +/// +/// See https://pytorch.org/docs/main/fft.html#torch.fft.fftshift +/// +/// Example: +/// ``` +/// auto x = torch::randn({127, 4}); +/// auto centred_fft = torch::fft::fftshift(torch::fft::fftn(x)); +/// ``` +inline Tensor fftshift( + const Tensor& x, + at::OptionalIntArrayRef dim = std::nullopt) { + return torch::fft_fftshift(x, dim); +} + +/// Inverse of torch.fft.fftshift +/// +/// See https://pytorch.org/docs/main/fft.html#torch.fft.ifftshift +/// +/// Example: +/// ``` +/// auto x = torch::randn({127, 4}); +/// auto shift = torch::fft::fftshift(x) +/// auto unshift = torch::fft::ifftshift(shift); +/// assert(torch::allclose(x, unshift)); +/// ``` +inline Tensor ifftshift( + const Tensor& x, + at::OptionalIntArrayRef dim = std::nullopt) { + return torch::fft_ifftshift(x, dim); +} + +} // namespace torch::fft diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/imethod.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/imethod.h new file mode 100644 index 0000000000000000000000000000000000000000..1d3bdd04449de6c9a38c415c92b52e3dbbb6881b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/imethod.h @@ -0,0 +1,53 @@ +#pragma once +#include +#include + +namespace torch { + +class TORCH_API IMethod { + /* + IMethod provides a portable interface for torch methods, whether + they are backed by torchscript or python/deploy. + + This is helpful since torchscript methods provide additional information + (e.g. FunctionSchema, Graph) which aren't available in pure python methods. + + Higher level APIs should prefer depending on this interface rather + than a specific implementation of it, to promote portability and reuse, and + avoid unintentional dependencies on e.g. script methods. + + Note: This API is experimental, and may evolve. + */ + public: + using IValueList = std::vector; + using IValueMap = std::unordered_map; + + IMethod() = default; + IMethod(const IMethod&) = default; + IMethod& operator=(const IMethod&) = default; + IMethod(IMethod&&) noexcept = default; + IMethod& operator=(IMethod&&) noexcept = default; + virtual ~IMethod() = default; + + virtual c10::IValue operator()( + std::vector args, + const IValueMap& kwargs = IValueMap()) const = 0; + + virtual const std::string& name() const = 0; + + // Returns an ordered list of argument names, possible in both + // script and python methods. This is a more portable dependency + // than a ScriptMethod FunctionSchema, which has more information + // than can be generally expected from a python method. + const std::vector& getArgumentNames() const; + + protected: + virtual void setArgumentNames( + std::vector& argumentNames) const = 0; + + private: + mutable bool isArgumentNamesInitialized_{false}; + mutable std::vector argumentNames_; +}; + +} // namespace torch diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/jit.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/jit.h new file mode 100644 index 0000000000000000000000000000000000000000..19651f23ba3818b25d476a1bfd20b943bb6ccbab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/jit.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +#include +#include + +namespace torch::jit { + +/// Compiles script code into an executable graph. +/// +/// Takes a string containing functions in script syntax and compiles them into +/// a module (graph). The returned module provides a `run_method` function +/// that may be used to invoke the compiled functions. +/// +/// For example: +/// \rst +/// .. code-block:: cpp +/// +/// auto module = torch::jit::compile(R"JIT( +/// def relu_script(a, b): +/// return torch.relu(a + b) +/// def test_while(a, i): +/// while i < 10: +/// a += a +/// i += 1 +/// return a +/// )JIT"); +/// IValue output = module->run_method("relu_script", a, b); +/// \endrst +TORCH_API std::shared_ptr compile(const std::string& source); + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/mps.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/mps.h new file mode 100644 index 0000000000000000000000000000000000000000..576b8835a413e9ccbaebc05c58332c78330d17b9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/mps.h @@ -0,0 +1,42 @@ +#pragma once + +#include + +#include +#include + +#ifdef __OBJC__ +#include +#include +using MTLCommandBuffer_t = id; +using DispatchQueue_t = dispatch_queue_t; +#else +using MTLCommandBuffer_t = void*; +using DispatchQueue_t = void*; +#endif + +namespace torch::mps { + +/// Returns true if MPS device is available. +bool TORCH_API is_available(); + +/// Sets the RNG seed for the MPS device. +void TORCH_API manual_seed(uint64_t seed); + +/// Waits for all streams on the MPS device to complete. +/// This blocks the calling CPU thread by using the 'waitUntilCompleted()' +/// method to wait for Metal command buffers finish executing all the +/// encoded GPU operations before returning. +void TORCH_API synchronize(); + +/// Submits the currently active command buffer to run on the MPS device. +void TORCH_API commit(); + +/// Get the current command buffer to encode the Metal commands. +MTLCommandBuffer_t TORCH_API get_command_buffer(); + +/// Get the dispatch_queue_t to synchronize encoding the custom kernels +/// with the PyTorch MPS backend. +DispatchQueue_t TORCH_API get_dispatch_queue(); + +} // namespace torch::mps diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nested.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nested.h new file mode 100644 index 0000000000000000000000000000000000000000..0340d1f2b34f476cb885ccdae36d20e49b7c5ec5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nested.h @@ -0,0 +1,93 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nested { + +/// Nested tensor +/// +/// See +/// https://pytorch.org/docs/main/nested.html#torch.nested.nested_tensor +/// +/// ``` +// implemented on python object to allow torch.nested.nested_tensor to be +// constructed with arbitrarily nested python objects - for now, only arbitrary +// python lists and lists of Tensors +// See torch/csrc/autograd/python_nested_functions_manual.cpp for Python +// implementation +// See here for C++ implementation +inline at::Tensor nested_tensor( + at::TensorList nested_tensor_data, + const at::TensorOptions& options = {}) { + auto out = at::_nested_tensor_from_tensor_list( + nested_tensor_data, + c10::typeMetaToScalarType(options.dtype()), + std::nullopt, + options.device(), + options.pinned_memory()); + if (options.has_requires_grad() && options.requires_grad()) { + out.requires_grad_(true); + } + return out; +} + +inline at::Tensor nested_tensor( + at::ArrayRef nested_tensor_data, + const at::TensorOptions& options = {}) { + for (const auto& tdc : nested_tensor_data) { + TORCH_CHECK( + tdc.is_init_list(), + "nested_tensor() not implemented for these parameters"); + } + // Construct a TensorList using nested_tensor_data + std::vector tensor_list(nested_tensor_data.size()); + std::transform( + nested_tensor_data.begin(), + nested_tensor_data.end(), + tensor_list.begin(), + [&](const detail::TensorDataContainer& tdc) { + return tdc.convert_to_tensor(options); + }); + auto out = at::_nested_tensor_from_tensor_list( + tensor_list, + c10::typeMetaToScalarType(options.dtype()), + std::nullopt, + options.device(), + options.pinned_memory()); + if (options.has_requires_grad() && options.requires_grad()) { + out.requires_grad_(true); + } + return out; +} + +/// As Nested Tensor +/// +/// See +/// https://pytorch.org/docs/main/nested.html#torch.nested.as_nested_tensor +/// +/// ``` +inline at::Tensor as_nested_tensor( + at::TensorList list, + std::optional dtype = std::nullopt, + std::optional device = std::nullopt) { + return at::_nested_tensor_from_tensor_list( + list, dtype, std::nullopt, device, std::nullopt); +} + +/// Nested to padded tensor +/// +/// See +/// https://pytorch.org/docs/main/nested.html#torch.nested.to_padded_tensor +/// +/// ``` +inline at::Tensor to_padded_tensor( + const at::Tensor& self, + double padding, + at::OptionalIntArrayRef output_size = std::nullopt) { + return at::nested_to_padded_tensor(self, padding, output_size); +} + +} // namespace torch::nested diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn.h new file mode 100644 index 0000000000000000000000000000000000000000..b93220b5d62a0ccf64b16a3b8aae8cb940045849 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn.h @@ -0,0 +1,10 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/cloneable.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/cloneable.h new file mode 100644 index 0000000000000000000000000000000000000000..9a582435815169cb62e795967434bb5d3a89fa0f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/cloneable.h @@ -0,0 +1,94 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include + +namespace torch::nn { +/// The `clone()` method in the base `Module` class does not have knowledge of +/// the concrete runtime type of its subclasses. Therefore, `clone()` must +/// either be called from within the subclass, or from a base class that has +/// knowledge of the concrete type. `Cloneable` uses the CRTP to gain +/// knowledge of the subclass' static type and provide an implementation of the +/// `clone()` method. We do not want to use this pattern in the base class, +/// because then storing a module would always require templatizing it. +template +// NOLINTNEXTLINE(bugprone-exception-escape) +class Cloneable : public Module { + public: + using Module::Module; + + /// `reset()` must perform initialization of all members with reference + /// semantics, most importantly parameters, buffers and submodules. + virtual void reset() = 0; + + /// Performs a recursive "deep copy" of the `Module`, such that all parameters + /// and submodules in the cloned module are different from those in the + /// original module. + std::shared_ptr clone( + const std::optional& device = std::nullopt) const override { + NoGradGuard no_grad; + + const auto& self = static_cast(*this); + auto copy = std::make_shared(self); + copy->parameters_.clear(); + copy->buffers_.clear(); + copy->children_.clear(); + copy->reset(); + TORCH_CHECK( + copy->parameters_.size() == parameters_.size(), + "The cloned module does not have the same number of " + "parameters as the original module after calling reset(). " + "Are you sure you called register_parameter() inside reset() " + "and not the constructor?"); + for (const auto& parameter : named_parameters(/*recurse=*/false)) { + auto& tensor = *parameter; + auto data = device && tensor.device() != *device ? tensor.to(*device) + : tensor.clone(); + copy->parameters_[parameter.key()].set_data(data); + } + TORCH_CHECK( + copy->buffers_.size() == buffers_.size(), + "The cloned module does not have the same number of " + "buffers as the original module after calling reset(). " + "Are you sure you called register_buffer() inside reset() " + "and not the constructor?"); + for (const auto& buffer : named_buffers(/*recurse=*/false)) { + auto& tensor = *buffer; + auto data = device && tensor.device() != *device ? tensor.to(*device) + : tensor.clone(); + copy->buffers_[buffer.key()].set_data(data); + } + TORCH_CHECK( + copy->children_.size() == children_.size(), + "The cloned module does not have the same number of " + "child modules as the original module after calling reset(). " + "Are you sure you called register_module() inside reset() " + "and not the constructor?"); + for (const auto& child : children_) { + copy->children_[child.key()]->clone_(*child.value(), device); + } + return copy; + } + + private: + void clone_(Module& other, const std::optional& device) final { + // Here we are *pretty* certain that `other's` type is `Derived` (because it + // was registered under the same name as `this`), but you never know what + // crazy things `reset()` does, so `dynamic_cast` just to be safe. + auto clone = std::dynamic_pointer_cast(other.clone(device)); + TORCH_CHECK( + clone != nullptr, + "Attempted to clone submodule, but it is of a " + "different type than the submodule it was to be cloned into"); + static_cast(*this) = *clone; + } +}; + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional.h new file mode 100644 index 0000000000000000000000000000000000000000..b148edc68173f4d11cf58e042902edf3c508afff --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/activation.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/activation.h new file mode 100644 index 0000000000000000000000000000000000000000..49de1c8af63f33fbb9071e748e02e73f213f1444 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/activation.h @@ -0,0 +1,961 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch::nn::functional { + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor elu(Tensor input, double alpha, bool inplace) { + if (inplace) { + return torch::elu_(input, alpha); + } else { + return torch::elu(input, alpha); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.elu +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::ELUFuncOptions` class to +/// learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::elu(x, F::ELUFuncOptions().alpha(0.42).inplace(true)); +/// ``` +inline Tensor elu(Tensor input, const ELUFuncOptions& options = {}) { + return detail::elu(std::move(input), options.alpha(), options.inplace()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor selu(Tensor input, bool inplace) { + if (inplace) { + return torch::selu_(input); + } else { + return torch::selu(input); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.selu +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::SELUFuncOptions` class to +/// learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::selu(input, F::SELUFuncOptions(false)); +/// ``` +inline Tensor selu(Tensor input, const SELUFuncOptions& options = {}) { + return detail::selu(std::move(input), options.inplace()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor hardshrink(const Tensor& input, double lambda) { + return torch::hardshrink(input, lambda); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.hardshrink +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::HardshrinkFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::hardshrink(x, F::HardshrinkFuncOptions().lambda(0.42)); +/// ``` +inline Tensor hardshrink( + const Tensor& input, + const HardshrinkFuncOptions& options = {}) { + return detail::hardshrink(input, options.lambda()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor hardtanh( + Tensor input, + double min_val, + double max_val, + bool inplace) { + if (inplace) { + return torch::hardtanh_(input, min_val, max_val); + } else { + return torch::hardtanh(input, min_val, max_val); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.hardtanh +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::HardtanhFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::hardtanh(x, +/// F::HardtanhFuncOptions().min_val(-1.0).max_val(1.0).inplace(true)); +/// ``` +inline Tensor hardtanh(Tensor input, const HardtanhFuncOptions& options = {}) { + return detail::hardtanh( + std::move(input), + options.min_val(), + options.max_val(), + options.inplace()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor leaky_relu(Tensor input, double negative_slope, bool inplace) { + if (inplace) { + return torch::leaky_relu_(input, negative_slope); + } else { + return torch::leaky_relu(input, negative_slope); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.leaky_relu +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::LeakyReLUFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::leaky_relu(x, +/// F::LeakyReLUFuncOptions().negative_slope(0.42).inplace(true)); +/// ``` +inline Tensor leaky_relu( + Tensor input, + const LeakyReLUFuncOptions& options = {}) { + return detail::leaky_relu( + std::move(input), options.negative_slope(), options.inplace()); +} + +// ============================================================================ + +inline Tensor logsigmoid(const Tensor& input) { + return torch::log_sigmoid(input); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor gumbel_softmax( + const Tensor& logits, + double tau, + bool hard, + int dim) { + auto gumbels = + -torch::empty_like(logits).exponential_().log(); // ~Gumbel(0,1) + gumbels = (logits + gumbels) / tau; // ~Gumbel(logits, tau) + auto y_soft = gumbels.softmax(dim); + + torch::Tensor ret; + if (hard) { + // Straight through. + auto index = std::get<1>(y_soft.max(dim, /*keepdim=*/true)); + auto y_hard = torch::zeros_like(logits).scatter_(dim, index, 1.0); + ret = y_hard - y_soft.detach() + y_soft; + } else { + ret = y_soft; + } + return ret; +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.gumbel_softmax +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::GumbelSoftmaxFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(-1)); +/// ``` +inline Tensor gumbel_softmax( + const Tensor& logits, + const GumbelSoftmaxFuncOptions& options = {}) { + return detail::gumbel_softmax( + logits, options.tau(), options.hard(), options.dim()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor softmax( + const Tensor& input, + int64_t dim, + std::optional dtype) { + Tensor ret; + + if (dtype == std::nullopt) { + ret = input.softmax(dim); + } else { + ret = input.softmax(dim, dtype); + } + + return ret; +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.softmax +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::SoftmaxFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::softmax(input, F::SoftmaxFuncOptions(1)); +/// ``` +inline Tensor softmax(const Tensor& input, const SoftmaxFuncOptions& options) { + return detail::softmax(input, options.dim(), options.dtype()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor softmin( + const Tensor& input, + int64_t dim, + std::optional dtype) { + Tensor ret; + + if (dtype == std::nullopt) { + ret = (-input).softmax(dim); + } else { + ret = (-input).softmax(dim, dtype); + } + + return ret; +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.softmin +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::SoftminFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::softmin(input, F::SoftminFuncOptions(1)); +/// ``` +inline Tensor softmin(const Tensor& input, const SoftminFuncOptions& options) { + return detail::softmin(input, options.dim(), options.dtype()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor log_softmax( + const Tensor& input, + int64_t dim, + std::optional dtype) { + Tensor ret; + + if (dtype == std::nullopt) { + ret = input.log_softmax(dim); + } else { + ret = input.log_softmax(dim, dtype); + } + + return ret; +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.log_softmax +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::LogSoftmaxFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::log_softmax(input, LogSoftmaxFuncOptions(1)); +/// ``` +inline Tensor log_softmax( + const Tensor& input, + const LogSoftmaxFuncOptions& options) { + return detail::log_softmax(input, options.dim(), options.dtype()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor glu(const Tensor& input, int64_t dim) { + TORCH_CHECK( + input.dim() != 0, + "glu does not support scalars because halving size must be even"); + return torch::glu(input, dim); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.glu +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::GLUFuncOptions` class to +/// learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::glu(input, GLUFuncOptions(1)); +/// ``` +inline Tensor glu(const Tensor& input, const GLUFuncOptions& options = {}) { + return detail::glu(input, options.dim()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor gelu(const Tensor& input, const std::string& approximate) { + return torch::gelu(input, approximate); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +inline Tensor gelu(const Tensor& input, const GELUFuncOptions& options = {}) { + return detail::gelu(input, options.approximate()); +} + +// ============================================================================ + +inline Tensor silu(const Tensor& input) { + return torch::silu(input); +} + +// ============================================================================ + +inline Tensor mish(const Tensor& input) { + return torch::mish(input); +} + +// ============================================================================ + +inline Tensor prelu(const Tensor& input, const Tensor& weight) { + return torch::prelu(input, weight); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor relu(Tensor input, bool inplace) { + if (inplace) { + return torch::relu_(input); + } else { + return torch::relu(input); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.relu +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::ReLUFuncOptions` class to +/// learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::relu(x, F::ReLUFuncOptions().inplace(true)); +/// ``` +inline Tensor relu(Tensor input, const ReLUFuncOptions& options = {}) { + return detail::relu(std::move(input), options.inplace()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor relu6(Tensor input, bool inplace) { + if (inplace) { + return torch::relu6_(input); + } else { + return torch::relu6(input); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.relu6 +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::ReLU6FuncOptions` class to +/// learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::relu6(x, F::ReLU6FuncOptions().inplace(true)); +/// ``` +inline Tensor relu6(Tensor input, const ReLU6FuncOptions& options = {}) { + return detail::relu6(std::move(input), options.inplace()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor rrelu( + Tensor input, + double lower, + double upper, + bool training, + bool inplace) { + if (inplace) { + return torch::rrelu_(input, lower, upper, training); + } else { + return torch::rrelu(input, lower, upper, training); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.rrelu +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::RReLUFuncOptions` class to +/// learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::rrelu(x, F::RReLUFuncOptions().lower(0.1).upper(0.4).inplace(true)); +/// ``` +inline Tensor rrelu(Tensor input, const RReLUFuncOptions& options = {}) { + return detail::rrelu( + std::move(input), + options.lower(), + options.upper(), + options.training(), + options.inplace()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor celu(Tensor input, double alpha, bool inplace) { + if (inplace) { + return torch::celu_(input, alpha); + } else { + return torch::celu(input, alpha); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.celu +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::CELUFuncOptions` class to +/// learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::celu(x, F::CELUFuncOptions().alpha(0.42).inplace(true)); +/// ``` +inline Tensor celu(Tensor input, const CELUFuncOptions& options = {}) { + return detail::celu(std::move(input), options.alpha(), options.inplace()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor softplus(const Tensor& input, double beta, double threshold) { + return torch::softplus(input, beta, threshold); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.softplus +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::SoftplusFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::softplus(x, F::SoftplusFuncOptions().beta(0.5).threshold(3.0)); +/// ``` +inline Tensor softplus( + const Tensor& input, + const SoftplusFuncOptions& options = {}) { + return detail::softplus(input, options.beta(), options.threshold()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor softshrink(const Tensor& input, double lambda) { + return torch::softshrink(input, lambda); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.softshrink +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::SoftshrinkFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::softshrink(x, F::SoftshrinkFuncOptions(0.42)); +/// ``` +inline Tensor softshrink( + const Tensor& input, + const SoftshrinkFuncOptions& options = {}) { + return detail::softshrink(input, options.lambda()); +} + +// ============================================================================ + +inline Tensor softsign(const Tensor& input) { + return input / (input.abs() + 1); +} + +// ============================================================================ + +inline Tensor tanhshrink(const Tensor& input) { + return input - input.tanh(); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor threshold( + Tensor input, + double threshold, + double value, + bool inplace) { + if (inplace) { + return torch::threshold_(input, threshold, value); + } else { + return torch::threshold(input, threshold, value); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.threshold +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::ThresholdFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::threshold(x, F::ThresholdFuncOptions(0.5, 0.5).inplace(true)); +/// ``` +inline Tensor threshold(Tensor input, const ThresholdFuncOptions& options) { + return detail::threshold( + std::move(input), + options.threshold(), + options.value(), + options.inplace()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline std::tuple multi_head_attention_forward( + const Tensor& query, + const Tensor& key, + const Tensor& value, + int64_t embed_dim_to_check, + int64_t num_heads, + const Tensor& in_proj_weight, + const Tensor& in_proj_bias, + const Tensor& bias_k, + const Tensor& bias_v, + bool add_zero_attn, + double dropout_p, + const Tensor& out_proj_weight, + const Tensor& out_proj_bias, + bool training = true, + const Tensor& key_padding_mask = {}, + bool need_weights = true, + const Tensor& attn_mask = {}, + bool use_separate_proj_weight = false, + const Tensor& q_proj_weight = {}, + const Tensor& k_proj_weight = {}, + const Tensor& v_proj_weight = {}, + const Tensor& static_k = {}, + const Tensor& static_v = {}, + bool average_attn_weights = true) { + namespace F = torch::nn::functional; + + const auto query_sizes = query.sizes(); + const auto& tgt_len = query_sizes[0]; + const auto& bsz = query_sizes[1]; + const auto& embed_dim = query_sizes[2]; + TORCH_INTERNAL_ASSERT(embed_dim == embed_dim_to_check); + TORCH_INTERNAL_ASSERT(key.sizes() == value.sizes()); + + const auto head_dim = embed_dim / num_heads; + TORCH_CHECK( + head_dim * num_heads == embed_dim, + "embed_dim must be divisible by num_heads"); + const auto scaling = 1 / std::sqrt(head_dim); + + Tensor q, k, v; + if (!use_separate_proj_weight) { + if (torch::equal(query, key) && torch::equal(key, value)) { + // self-attention + const auto chunks = + F::linear(query, in_proj_weight, in_proj_bias).chunk(3, /*dim=*/-1); + q = chunks[0]; + k = chunks[1]; + v = chunks[2]; + } else if (torch::equal(key, value)) { + // encoder-decoder attention + // This is inline in_proj function with in_proj_weight and in_proj_bias + auto _b = in_proj_bias; + int64_t _start = 0; + auto _end = embed_dim; + auto _w = in_proj_weight.slice(/*dim=*/0, _start, _end); + if (_b.defined()) { + _b = _b.slice(/*dim=*/0, _start, _end); + } + q = F::linear(query, _w, _b); + + if (!key.defined()) { + TORCH_INTERNAL_ASSERT(!value.defined()); + k.reset(); + v.reset(); + } else { + // This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias; + _start = embed_dim; + _w = in_proj_weight.slice(/*dim=*/0, _start); + if (_b.defined()) { + _b = _b.slice(/*dim=*/0, _start); + } + const auto chunks = F::linear(key, _w, _b).chunk(2, /*dim=*/-1); + k = chunks[0]; + v = chunks[1]; + } + } else { + // This is inline in_proj function with in_proj_weight and in_proj_bias + auto _b = in_proj_bias; + int64_t _start = 0; + auto _end = embed_dim; + auto _w = in_proj_weight.slice(/*dim=*/0, _start, _end); + if (_b.defined()) { + _b = _b.slice(/*dim=*/0, _start, _end); + } + q = F::linear(query, _w, _b); + + // This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias; + _start = embed_dim; + _end = embed_dim * 2; + _w = in_proj_weight.slice(/*dim=*/0, _start, _end); + if (_b.defined()) { + _b = _b.slice(/*dim=*/0, _start, _end); + } + k = F::linear(key, _w, _b); + + // This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias; + _start = embed_dim * 2; + _w = in_proj_weight.slice(/*dim=*/0, _start); + if (_b.defined()) { + _b = _b.slice(0, _start); + } + v = F::linear(value, _w, _b); + } + } else { + const auto& q_proj_weight_non_opt = q_proj_weight; + { + const auto sizes = q_proj_weight_non_opt.sizes(); + const auto len1 = sizes[0]; + const auto len2 = sizes[1]; + TORCH_CHECK(len1 == embed_dim && len2 == query.size(-1)); + } + + const auto& k_proj_weight_non_opt = k_proj_weight; + { + const auto sizes = k_proj_weight_non_opt.sizes(); + const auto len1 = sizes[0]; + const auto len2 = sizes[1]; + TORCH_CHECK(len1 == embed_dim && len2 == key.size(-1)); + } + + const auto& v_proj_weight_non_opt = v_proj_weight; + { + const auto sizes = v_proj_weight_non_opt.sizes(); + const auto len1 = sizes[0]; + const auto len2 = sizes[1]; + TORCH_CHECK(len1 == embed_dim && len2 == value.size(-1)); + } + + if (in_proj_bias.defined()) { + q = F::linear( + query, + q_proj_weight_non_opt, + in_proj_bias.slice(/*dim=*/0, 0, embed_dim)); + k = F::linear( + key, + k_proj_weight_non_opt, + in_proj_bias.slice(/*dim=*/0, embed_dim, (embed_dim * 2))); + v = F::linear( + value, + v_proj_weight_non_opt, + in_proj_bias.slice(/*dim=*/0, (embed_dim * 2))); + } else { + q = F::linear(query, q_proj_weight_non_opt, in_proj_bias); + k = F::linear(key, k_proj_weight_non_opt, in_proj_bias); + v = F::linear(value, v_proj_weight_non_opt, in_proj_bias); + } + } + q = q * scaling; + Tensor attn_mask_ = attn_mask; + Tensor key_padding_mask_ = key_padding_mask; + if (bias_k.defined() && bias_v.defined()) { + if (!static_k.defined() && !static_v.defined()) { + k = torch::cat({k, bias_k.repeat({1, bsz, 1})}); + v = torch::cat({v, bias_v.repeat({1, bsz, 1})}); + if (attn_mask_.defined()) { + attn_mask_ = torch::cat( + {attn_mask_, + torch::zeros( + {attn_mask_.size(0), 1}, + at::TensorOptions(attn_mask_.dtype()) + .device(attn_mask_.device()))}, + /*dim=*/1); + } + if (key_padding_mask_.defined()) { + key_padding_mask_ = torch::cat( + {key_padding_mask_, + torch::zeros( + {key_padding_mask_.size(0), 1}, + at::TensorOptions(key_padding_mask_.dtype()) + .device(key_padding_mask_.device()))}, + /*dim=*/1); + } + } else { + TORCH_CHECK(!static_k.defined(), "bias cannot be added to static key."); + TORCH_CHECK(!static_v.defined(), "bias cannot be added to static value."); + } + } else { + TORCH_CHECK(!bias_k.defined()); + TORCH_CHECK(!bias_v.defined()); + } + q = q.contiguous().view({tgt_len, bsz * num_heads, head_dim}).transpose(0, 1); + if (k.defined()) { + k = k.contiguous().view({-1, bsz * num_heads, head_dim}).transpose(0, 1); + } + if (v.defined()) { + v = v.contiguous().view({-1, bsz * num_heads, head_dim}).transpose(0, 1); + } + if (static_k.defined()) { + TORCH_CHECK(static_k.size(0) == bsz * num_heads); + TORCH_CHECK(static_k.size(2) == head_dim); + k = static_k; + } + if (static_v.defined()) { + TORCH_CHECK(static_v.size(0) == bsz * num_heads); + TORCH_CHECK(static_v.size(2) == head_dim); + v = static_v; + } + auto src_len = k.size(1); + if (key_padding_mask_.defined()) { + TORCH_CHECK(key_padding_mask_.size(0) == bsz); + TORCH_CHECK(key_padding_mask_.size(1) == src_len); + } + if (add_zero_attn) { + src_len += 1; + auto k_sizes = k.sizes().vec(); + k_sizes[1] = 1; + k = torch::cat( + {k, + torch::zeros( + k_sizes, at::TensorOptions(k.dtype()).device(k.device()))}, + /*dim=*/1); + auto v_sizes = v.sizes().vec(); + v_sizes[1] = 1; + v = torch::cat( + {v, + torch::zeros( + v_sizes, at::TensorOptions(v.dtype()).device(v.device()))}, + /*dim=*/1); + if (attn_mask_.defined()) { + attn_mask_ = torch::cat( + {attn_mask_, + torch::zeros( + {attn_mask_.size(0), 1}, + at::TensorOptions(attn_mask_.dtype()) + .device(attn_mask_.device()))}, + /*dim=*/1); + } + if (key_padding_mask_.defined()) { + key_padding_mask_ = torch::cat( + {key_padding_mask_, + torch::zeros( + {key_padding_mask_.size(0), 1}, + at::TensorOptions(key_padding_mask_.dtype()) + .device(key_padding_mask_.device()))}, + /*dim=*/1); + } + } + auto attn_output_weights = torch::bmm(q, k.transpose(1, 2)); + TORCH_CHECK( + attn_output_weights.sizes() == + IntArrayRef({bsz * num_heads, tgt_len, src_len})); + if (attn_mask_.defined()) { + attn_mask_ = attn_mask_.unsqueeze(0); + attn_output_weights += attn_mask_; + } + if (key_padding_mask_.defined()) { + attn_output_weights = + attn_output_weights.view({bsz, num_heads, tgt_len, src_len}); + attn_output_weights = AT_DISPATCH_FLOATING_TYPES( + attn_output_weights.scalar_type(), + "attn_output_weights.masked_fill", + [&]() { + return attn_output_weights.masked_fill( + key_padding_mask_.unsqueeze(1).unsqueeze(2), + -std::numeric_limits::infinity()); + }); + attn_output_weights = + attn_output_weights.view({bsz * num_heads, tgt_len, src_len}); + } + attn_output_weights = F::softmax(attn_output_weights, /*options=*/-1); + attn_output_weights = F::dropout( + attn_output_weights, + F::DropoutFuncOptions().p(dropout_p).training(training)); + auto attn_output = torch::bmm(attn_output_weights, v); + TORCH_CHECK( + attn_output.sizes() == IntArrayRef({bsz * num_heads, tgt_len, head_dim})); + attn_output = + attn_output.transpose(0, 1).contiguous().view({tgt_len, bsz, embed_dim}); + attn_output = F::linear(attn_output, out_proj_weight, out_proj_bias); + if (need_weights) { + attn_output_weights = + attn_output_weights.view({bsz, num_heads, tgt_len, src_len}); + if (average_attn_weights) { + // average attention weights over heads + attn_output_weights = attn_output_weights.sum(/*dim=*/1) / num_heads; + } + return std::make_tuple(attn_output, attn_output_weights); + } else { + return std::make_tuple(attn_output, Tensor()); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +inline std::tuple multi_head_attention_forward( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const MultiheadAttentionForwardFuncOptions& options) { + return detail::multi_head_attention_forward( + query, + key, + value, + options.embed_dim_to_check(), + options.num_heads(), + options.in_proj_weight(), + options.in_proj_bias(), + options.bias_k(), + options.bias_v(), + options.add_zero_attn(), + options.dropout_p(), + options.out_proj_weight(), + options.out_proj_bias(), + options.training(), + options.key_padding_mask(), + options.need_weights(), + options.attn_mask(), + options.use_separate_proj_weight(), + options.q_proj_weight(), + options.k_proj_weight(), + options.v_proj_weight(), + options.static_k(), + options.static_v(), + options.average_attn_weights()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/batchnorm.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/batchnorm.h new file mode 100644 index 0000000000000000000000000000000000000000..0427f3bb828d5065994a4bb5c7207d241cabfa3c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/batchnorm.h @@ -0,0 +1,79 @@ +#pragma once + +#include +#include +#include + +namespace torch::nn::functional { + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor batch_norm( + const Tensor& input, + const Tensor& running_mean, + const Tensor& running_var, + Tensor weight, + Tensor bias, + bool training, + double momentum, + double eps) { + TORCH_CHECK( + input.dim() >= 2, + "Expected at least 2 input dimensions, but got ", + input.dim()); + if (training) { + auto size = input.sizes(); + int64_t size_prods = size[0]; + for (const auto i : c10::irange(size.size() - 2)) { + size_prods *= size[i + 2]; + } + TORCH_CHECK( + size_prods != 1, + "Expected more than 1 value per channel when training, got input size ", + size); + } + + return torch::batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + at::globalContext().userEnabledCuDNN()); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.batch_norm +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::BatchNormFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::batch_norm(input, mean, variance, +/// F::BatchNormFuncOptions().weight(weight).bias(bias).momentum(0.1).eps(1e-05).training(false)); +/// ``` +inline Tensor batch_norm( + const Tensor& input, + const Tensor& running_mean, + const Tensor& running_var, + const BatchNormFuncOptions& options = {}) { + return detail::batch_norm( + input, + running_mean, + running_var, + options.weight(), + options.bias(), + options.training(), + options.momentum(), + options.eps()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/conv.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/conv.h new file mode 100644 index 0000000000000000000000000000000000000000..1c2b5b73c48dc014504a1f600fcc1e8e26cfbbd2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/conv.h @@ -0,0 +1,297 @@ +#pragma once + +#include +#include + +namespace torch::nn::functional { + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { + +inline std::string padding_unwrap(enumtype::kValid) { + return "valid"; +} + +inline std::string padding_unwrap(enumtype::kSame) { + return "same"; +} + +template +IntArrayRef padding_unwrap(const ExpandingArray& array) { + return array; +} + +inline Tensor conv1d( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + ExpandingArray<1> stride, + const Conv1dFuncOptions::padding_t& padding, + ExpandingArray<1> dilation, + int64_t groups) { + return std::visit( + [&](const auto& pad) { + return torch::conv1d( + input, weight, bias, stride, padding_unwrap(pad), dilation, groups); + }, + padding); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv1d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::Conv1dFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::conv1d(x, weight, F::Conv1dFuncOptions().stride(1)); +/// ``` +inline Tensor conv1d( + const Tensor& input, + const Tensor& weight, + const Conv1dFuncOptions& options = {}) { + return detail::conv1d( + input, + weight, + options.bias(), + options.stride(), + options.padding(), + options.dilation(), + options.groups()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor conv2d( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + ExpandingArray<2> stride, + const Conv2dFuncOptions::padding_t& padding, + ExpandingArray<2> dilation, + int64_t groups) { + return std::visit( + [&](const auto& pad) { + return torch::conv2d( + input, weight, bias, stride, padding_unwrap(pad), dilation, groups); + }, + padding); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv2d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::Conv2dFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1)); +/// ``` +inline Tensor conv2d( + const Tensor& input, + const Tensor& weight, + const Conv2dFuncOptions& options = {}) { + return detail::conv2d( + input, + weight, + options.bias(), + options.stride(), + options.padding(), + options.dilation(), + options.groups()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor conv3d( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + ExpandingArray<3> stride, + const Conv3dFuncOptions::padding_t& padding, + ExpandingArray<3> dilation, + int64_t groups) { + return std::visit( + [&](const auto& pad) { + return torch::conv3d( + input, weight, bias, stride, padding_unwrap(pad), dilation, groups); + }, + padding); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv3d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::Conv3dFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::conv3d(x, weight, F::Conv3dFuncOptions().stride(1)); +/// ``` +inline Tensor conv3d( + const Tensor& input, + const Tensor& weight, + const Conv3dFuncOptions& options = {}) { + return detail::conv3d( + input, + weight, + options.bias(), + options.stride(), + options.padding(), + options.dilation(), + options.groups()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor conv_transpose1d( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef output_padding, + int64_t groups, + IntArrayRef dilation) { + return torch::conv_transpose1d( + input, weight, bias, stride, padding, output_padding, groups, dilation); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv_transpose1d +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::ConvTranspose1dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::conv_transpose1d(x, weight, F::ConvTranspose1dFuncOptions().stride(1)); +/// ``` +inline Tensor conv_transpose1d( + const Tensor& input, + const Tensor& weight, + const ConvTranspose1dFuncOptions& options = {}) { + return detail::conv_transpose1d( + input, + weight, + options.bias(), + options.stride(), + options.padding(), + options.output_padding(), + options.groups(), + options.dilation()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor conv_transpose2d( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef output_padding, + int64_t groups, + IntArrayRef dilation) { + return torch::conv_transpose2d( + input, weight, bias, stride, padding, output_padding, groups, dilation); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv_transpose2d +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::ConvTranspose2dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1)); +/// ``` +inline Tensor conv_transpose2d( + const Tensor& input, + const Tensor& weight, + const ConvTranspose2dFuncOptions& options = {}) { + return detail::conv_transpose2d( + input, + weight, + options.bias(), + options.stride(), + options.padding(), + options.output_padding(), + options.groups(), + options.dilation()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor conv_transpose3d( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef output_padding, + int64_t groups, + IntArrayRef dilation) { + return torch::conv_transpose3d( + input, weight, bias, stride, padding, output_padding, groups, dilation); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.conv_transpose3d +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::ConvTranspose3dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::conv_transpose3d(x, weight, F::ConvTranspose3dFuncOptions().stride(1)); +/// ``` +inline Tensor conv_transpose3d( + const Tensor& input, + const Tensor& weight, + const ConvTranspose3dFuncOptions& options = {}) { + return detail::conv_transpose3d( + input, + weight, + options.bias(), + options.stride(), + options.padding(), + options.output_padding(), + options.groups(), + options.dilation()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/distance.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/distance.h new file mode 100644 index 0000000000000000000000000000000000000000..c5cb133aa609bd651d380d8ef5bf70368b37747c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/distance.h @@ -0,0 +1,84 @@ +#pragma once + +#include + +namespace torch::nn::functional { + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor cosine_similarity( + const Tensor& x1, + const Tensor& x2, + int64_t dim, + double eps) { + return torch::cosine_similarity(x1, x2, dim, eps); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.cosine_similarity +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::CosineSimilarityFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::cosine_similarity(input1, input2, +/// F::CosineSimilarityFuncOptions().dim(1)); +/// ``` +inline Tensor cosine_similarity( + const Tensor& x1, + const Tensor& x2, + const CosineSimilarityFuncOptions& options = {}) { + return detail::cosine_similarity(x1, x2, options.dim(), options.eps()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor pairwise_distance( + const Tensor& x1, + const Tensor& x2, + double p, + double eps, + bool keepdim) { + return torch::pairwise_distance(x1, x2, p, eps, keepdim); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.pairwise_distance +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::PairwiseDistanceFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::pairwise_distance(input1, input2, F::PairwiseDistanceFuncOptions().p(1)); +/// ``` +inline Tensor pairwise_distance( + const Tensor& x1, + const Tensor& x2, + const PairwiseDistanceFuncOptions& options = {}) { + return detail::pairwise_distance( + x1, x2, options.p(), options.eps(), options.keepdim()); +} + +// ============================================================================ + +/// Computes the p-norm distance between every pair of row vectors in the input. +/// This function will be faster if the rows are contiguous. +inline Tensor pdist(const Tensor& input, double p = 2.0) { + return torch::pdist(input, p); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/dropout.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/dropout.h new file mode 100644 index 0000000000000000000000000000000000000000..d365ff84004771f7d6d6f2f4d0e78b4624e5d71a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/dropout.h @@ -0,0 +1,230 @@ +#pragma once + +#include + +#include + +namespace torch::nn::functional { + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { + +inline Tensor dropout(Tensor input, double p, bool training, bool inplace) { + TORCH_CHECK( + p >= 0. && p <= 1., + "dropout probability has to be between 0 and 1, but got ", + p); + if (inplace) { + return torch::dropout_(input, p, training); + } else { + return torch::dropout(input, p, training); + } +} + +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.dropout +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::DropoutFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::dropout(input, F::DropoutFuncOptions().p(0.5)); +/// ``` +inline Tensor dropout(Tensor input, const DropoutFuncOptions& options = {}) { + return detail::dropout( + std::move(input), options.p(), options.training(), options.inplace()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { + +template +inline Tensor _dropoutNd_helper( + Tensor input, + double p, + bool training, + bool inplace, + const char* fn_name) { + TORCH_CHECK( + p >= 0. && p <= 1., + "dropout probability has to be between 0 and 1, but got ", + p); + + auto inp_dim = input.dim(); + auto is_batched = inp_dim == batched_dim; + if (!is_batched) { + if (inplace) { + input = input.unsqueeze_(0); + } else { + input = input.unsqueeze(0); + } + } + + Tensor result; + if (inplace) { + result = torch::feature_dropout_(input, p, training); + } else { + result = torch::feature_dropout(input, p, training); + } + + if (!is_batched) { + if (inplace) { + result = result.squeeze_(0); + } else { + result = result.squeeze(0); + } + } + return result; +} + +inline Tensor dropout2d(Tensor input, double p, bool training, bool inplace) { + return _dropoutNd_helper<3, 4>( + std::move(input), p, training, inplace, "dropout2d"); +} + +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.dropout2d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::Dropout2dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::dropout2d(input, F::Dropout2dFuncOptions().p(0.5)); +/// ``` +inline Tensor dropout2d( + Tensor input, + const Dropout2dFuncOptions& options = {}) { + return detail::dropout2d( + std::move(input), options.p(), options.training(), options.inplace()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { + +inline Tensor dropout3d(Tensor input, double p, bool training, bool inplace) { + return _dropoutNd_helper<4, 5>( + std::move(input), p, training, inplace, "dropout3d"); +} + +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.dropout3d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::Dropout3dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::dropout3d(input, F::Dropout3dFuncOptions().p(0.5)); +/// ``` +inline Tensor dropout3d( + Tensor input, + const Dropout3dFuncOptions& options = {}) { + return detail::dropout3d( + std::move(input), options.p(), options.training(), options.inplace()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { + +inline Tensor alpha_dropout( + Tensor input, + double p, + bool training, + bool inplace) { + if (p < 0. || p > 1.) { + TORCH_CHECK( + false, "dropout probability has to be between 0 and 1, but got ", p); + } + return inplace ? torch::alpha_dropout_(input, p, training) + : torch::alpha_dropout(input, p, training); +} + +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.alpha_dropout +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::AlphaDropoutFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::alpha_dropout(input, +/// F::AlphaDropoutFuncOptions().p(0.5).training(false)); +/// ``` +inline Tensor alpha_dropout( + Tensor input, + const AlphaDropoutFuncOptions& options = {}) { + return detail::alpha_dropout( + std::move(input), options.p(), options.training(), options.inplace()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { + +inline Tensor feature_alpha_dropout( + Tensor input, + double p, + bool training, + bool inplace) { + if (p < 0. || p > 1.) { + TORCH_CHECK( + false, "dropout probability has to be between 0 and 1, but got ", p); + } + return inplace ? torch::feature_alpha_dropout_(input, p, training) + : torch::feature_alpha_dropout(input, p, training); +} + +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.feature_alpha_dropout +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::FeatureAlphaDropoutFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::feature_alpha_dropout(input, +/// F::FeatureAlphaDropoutFuncOptions().p(0.5).training(false)); +/// ``` +inline Tensor feature_alpha_dropout( + Tensor input, + const FeatureAlphaDropoutFuncOptions& options = {}) { + return detail::feature_alpha_dropout( + std::move(input), options.p(), options.training(), options.inplace()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/embedding.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/embedding.h new file mode 100644 index 0000000000000000000000000000000000000000..fb8aa8d45b2b97e878730274e421c9c0397a68d0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/embedding.h @@ -0,0 +1,206 @@ +#pragma once + +#include + +namespace torch::nn::functional { + +inline Tensor one_hot(const Tensor& tensor, int64_t num_classes = -1) { + return torch::one_hot(tensor, num_classes); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline void _no_grad_embedding_renorm_( + Tensor weight, + const Tensor& input, + float max_norm, + float norm_type) { + torch::NoGradGuard no_grad; + torch::embedding_renorm_(weight, input, max_norm, norm_type); +} + +inline Tensor embedding( + const Tensor& input, + const Tensor& weight, + std::optional padding_idx, + std::optional max_norm, + double norm_type, + bool scale_grad_by_freq, + bool sparse) { + auto input_ = input; + + if (padding_idx != std::nullopt) { + if (*padding_idx > 0) { + TORCH_CHECK( + *padding_idx < weight.size(0), + "Padding_idx must be within num_embeddings"); + } else if (*padding_idx < 0) { + TORCH_CHECK( + *padding_idx >= -weight.size(0), + "Padding_idx must be within num_embedding"); + padding_idx = weight.size(0) + *padding_idx; + } + } else { + padding_idx = -1; + } + + if (max_norm != std::nullopt) { + input_ = input_.contiguous(); + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + _no_grad_embedding_renorm_(weight, input_, *max_norm, norm_type); + } + return torch::embedding( + weight, input_, *padding_idx, scale_grad_by_freq, sparse); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.embedding +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::EmbeddingFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::embedding(input, weight, +/// F::EmbeddingFuncOptions().norm_type(2.5).scale_grad_by_freq(true).sparse(true)); +/// ``` +inline Tensor embedding( + const Tensor& input, + const Tensor& weight, + const EmbeddingFuncOptions& options = {}) { + return detail::embedding( + input, + weight, + options.padding_idx(), + options.max_norm(), + options.norm_type(), + options.scale_grad_by_freq(), + options.sparse()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor embedding_bag( + const Tensor& input, + const Tensor& weight, + const Tensor& offsets, + std::optional max_norm, + double norm_type, + bool scale_grad_by_freq, + EmbeddingBagMode mode, + bool sparse, + const Tensor& per_sample_weights, + bool include_last_offset, + std::optional padding_idx) { + auto input_ = input; + auto offsets_ = offsets; + auto per_sample_weights_ = per_sample_weights; + TORCH_CHECK( + !per_sample_weights_.defined() || + input_.sizes() == per_sample_weights_.sizes(), + "embedding_bag: If per_sample_weights (", + per_sample_weights_.sizes(), + ") is not null, then it must have the same shape as the input (", + input_.sizes(), + ")"); + if (input_.dim() == 2) { + TORCH_CHECK( + !offsets_.defined(), + "If input is 2D, then offsets has to be null, as input is treated is a mini-batch of fixed length sequences. However, found offsets of type Tensor"); + offsets_ = torch::arange( + 0, + input_.numel(), + input_.size(1), + torch::TensorOptions().dtype(torch::kLong).device(input_.device())); + input_ = input_.reshape(-1); + if (per_sample_weights_.defined()) { + per_sample_weights_ = per_sample_weights_.reshape(-1); + } + } else if (input_.dim() == 1) { + TORCH_CHECK( + offsets_.defined(), "offsets has to be a 1D Tensor but got null"); + TORCH_CHECK(offsets_.dim() == 1, "offsets has to be a 1D Tensor"); + } else { + TORCH_CHECK( + false, + "input has to be 1D or 2D Tensor, but got Tensor of dimension ", + input_.dim()); + } + + int mode_enum = 0; + if (std::holds_alternative(mode)) { + mode_enum = 0; + } else if (std::holds_alternative(mode)) { + mode_enum = 1; + } else if (std::holds_alternative(mode)) { + mode_enum = 2; + TORCH_CHECK( + !scale_grad_by_freq, + "max mode does not support scaling the gradient by the frequency"); + TORCH_CHECK(!sparse, "max mode does not support sparse weights"); + } else { + TORCH_CHECK(false, "mode has to be one of sum, mean or max"); + } + + if (max_norm != std::nullopt) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + _no_grad_embedding_renorm_(weight, input_, *max_norm, norm_type); + } + + TORCH_CHECK( + !per_sample_weights_.defined() || std::get_if(&mode), + "embedding_bag: per_sample_weights was not null. ", + "per_sample_weights is only supported for mode='kSum' (got mode='", + torch::enumtype::get_enum_name(mode), + "').Please open a feature request on GitHub."); + + return std::get<0>(torch::embedding_bag( + weight, + input_, + offsets_, + scale_grad_by_freq, + mode_enum, + sparse, + per_sample_weights_, + include_last_offset, + padding_idx)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.embedding_bag +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::EmbeddingBagFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::embedding_bag(input, weight, +/// F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets)); +/// ``` +inline Tensor embedding_bag( + const Tensor& input, + const Tensor& weight, + const EmbeddingBagFuncOptions& options = {}) { + return detail::embedding_bag( + input, + weight, + options.offsets(), + options.max_norm(), + options.norm_type(), + options.scale_grad_by_freq(), + options.mode(), + options.sparse(), + options.per_sample_weights(), + options.include_last_offset(), + options.padding_idx()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/fold.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/fold.h new file mode 100644 index 0000000000000000000000000000000000000000..23b19d0bb8d58a3ec863311b8412257ad901a237 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/fold.h @@ -0,0 +1,98 @@ +#pragma once + +#include + +namespace torch::nn::functional { + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor fold( + const Tensor& input, + ExpandingArray<2> output_size, + ExpandingArray<2> kernel_size, + ExpandingArray<2> dilation, + ExpandingArray<2> padding, + ExpandingArray<2> stride) { + if (input.dim() == 3 || input.dim() == 2) { + return torch::col2im( + input, output_size, kernel_size, dilation, padding, stride); + } else { + TORCH_CHECK( + false, + "Input Error: Only unbatched (2D) or batched (3D) input Tensors are supported " + "(got ", + input.dim(), + "D)"); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.fold +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::FoldFuncOptions` class to +/// learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::fold(input, F::FoldFuncOptions({3, 2}, {2, 2})); +/// ``` +inline Tensor fold(const Tensor& input, const FoldFuncOptions& options) { + return detail::fold( + input, + options.output_size(), + options.kernel_size(), + options.dilation(), + options.padding(), + options.stride()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor unfold( + const Tensor& input, + ExpandingArray<2> kernel_size, + ExpandingArray<2> dilation, + ExpandingArray<2> padding, + ExpandingArray<2> stride) { + if (input.dim() == 4) { + return torch::im2col(input, kernel_size, dilation, padding, stride); + } else { + TORCH_CHECK( + false, + "Input Error: Only 4D input Tensors are supported " + "(got ", + input.dim(), + "D)"); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.unfold +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::UnfoldFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::unfold(input, F::UnfoldFuncOptions({2, 2}).padding(1).stride(2)); +/// ``` +inline Tensor unfold(const Tensor& input, const UnfoldFuncOptions& options) { + return detail::unfold( + input, + options.kernel_size(), + options.dilation(), + options.padding(), + options.stride()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/instancenorm.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/instancenorm.h new file mode 100644 index 0000000000000000000000000000000000000000..92f96946503190e6fb075192b45227cf663ee305 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/instancenorm.h @@ -0,0 +1,59 @@ +#pragma once + +#include + +namespace torch::nn::functional { + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor instance_norm( + const Tensor& input, + const Tensor& running_mean, + const Tensor& running_var, + const Tensor& weight, + const Tensor& bias, + bool use_input_stats, + double momentum, + double eps) { + return torch::instance_norm( + input, + weight, + bias, + running_mean, + running_var, + use_input_stats, + momentum, + eps, + at::globalContext().userEnabledCuDNN()); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.instance_norm +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::InstanceNormFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::instance_norm(input, +/// F::InstanceNormFuncOptions().running_mean(mean).running_var(variance).weight(weight).bias(bias).momentum(0.1).eps(1e-5)); +/// ``` +inline Tensor instance_norm( + const Tensor& input, + const InstanceNormFuncOptions& options = {}) { + return detail::instance_norm( + input, + options.running_mean(), + options.running_var(), + options.weight(), + options.bias(), + options.use_input_stats(), + options.momentum(), + options.eps()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/linear.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/linear.h new file mode 100644 index 0000000000000000000000000000000000000000..4d9e7fe6d4b7a94a6d9ed5565ec7db458272af4c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/linear.h @@ -0,0 +1,33 @@ +#pragma once + +#include + +namespace torch::nn::functional { + +inline Tensor bilinear( + const Tensor& input1, + const Tensor& input2, + const Tensor& weight, + const Tensor& bias = Tensor()) { + return torch::bilinear(input1, input2, weight, bias); +} + +// ============================================================================ + +inline Tensor linear( + const Tensor& input, + const Tensor& weight, + const Tensor& bias = {}) { + if (input.dim() == 2 && bias.defined()) { + // fused op is marginally faster + return torch::addmm(bias, input, weight.t()); + } else { + auto output = input.matmul(weight.t()); + if (bias.defined()) { + output += bias; + } + return output; + } +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/loss.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/loss.h new file mode 100644 index 0000000000000000000000000000000000000000..b81bf47cf54537d8435c69ac8187c0779cd97334 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/loss.h @@ -0,0 +1,1039 @@ +#pragma once + +#include +#include +#include + +namespace torch::nn::functional { + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor l1_loss( + const Tensor& input, + const Tensor& target, + L1LossFuncOptions::reduction_t reduction) { + return torch::l1_loss(input, target, enumtype::reduction_get_enum(reduction)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.l1_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::L1LossFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::l1_loss(input, target, F::L1LossFuncOptions(torch::kNone)); +/// ``` +inline Tensor l1_loss( + const Tensor& input, + const Tensor& target, + const L1LossFuncOptions& options = {}) { + return detail::l1_loss(input, target, options.reduction()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor kl_div( + const Tensor& input, + const Tensor& target, + KLDivFuncOptions::reduction_t reduction, + bool log_target = false) { + torch::Reduction::Reduction reduction_enum{}; + + if (std::holds_alternative(reduction)) { + TORCH_WARN( + "reduction: 'mean' divides the total loss by both the batch size and the support size." + "'batchmean' divides only by the batch size, and aligns with the KL div math definition." + "'mean' will be changed to behave the same as 'batchmean' in the next major release."); + } + + // special case for batchmean + if (std::holds_alternative(reduction)) { + reduction_enum = torch::Reduction::Sum; + } else { + reduction_enum = enumtype::reduction_get_enum(reduction); + } + + auto reduced = torch::kl_div(input, target, reduction_enum, log_target); + + if (std::holds_alternative(reduction) && + input.dim() != 0) { + reduced = reduced / input.sizes()[0]; + } + + return reduced; +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.kl_div +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::KLDivFuncOptions` class to +/// learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::kl_div(input, target, +/// F::KLDivFuncOptions.reduction(torch::kNone).log_target(false)); +/// ``` +inline Tensor kl_div( + const Tensor& input, + const Tensor& target, + const KLDivFuncOptions& options = {}) { + return detail::kl_div( + input, target, options.reduction(), options.log_target()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor mse_loss( + const Tensor& input, + const Tensor& target, + MSELossFuncOptions::reduction_t reduction) { + if (!(target.sizes() == input.sizes())) { + TORCH_WARN( + "Using a target size (", + target.sizes(), + ") that is different to the input size (", + input.sizes(), + "). ", + "This will likely lead to incorrect results due to broadcasting. ", + "Please ensure they have the same size."); + } + std::vector broadcast_tensors = + torch::broadcast_tensors({input, target}); + auto expanded_input = broadcast_tensors[0]; + auto expanded_target = broadcast_tensors[1]; + return torch::mse_loss( + expanded_input, expanded_target, enumtype::reduction_get_enum(reduction)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.mse_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::MSELossFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::mse_loss(input, target, F::MSELossFuncOptions(torch::kNone)); +/// ``` +inline Tensor mse_loss( + const Tensor& input, + const Tensor& target, + const MSELossFuncOptions& options = {}) { + return detail::mse_loss(input, target, options.reduction()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor binary_cross_entropy( + const Tensor& input, + const Tensor& target, + const Tensor& weight, + BinaryCrossEntropyFuncOptions::reduction_t reduction) { + auto reduction_enum = enumtype::reduction_get_enum(reduction); + + if (target.sizes() != input.sizes()) { + TORCH_CHECK( + false, + "Using a target size (", + target.sizes(), + ") ", + "that is different to the input size (", + input.sizes(), + ") is deprecated. ", + "Please ensure they have the same size."); + } + + auto weight_ = weight; + if (weight_.defined()) { + auto new_size = at::infer_size(target.sizes(), weight_.sizes()); + weight_ = weight_.expand(new_size); + } + + return torch::binary_cross_entropy(input, target, weight_, reduction_enum); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.binary_cross_entropy +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::BinaryCrossEntropyFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::binary_cross_entropy(input, target, +/// F::BinaryCrossEntropyFuncOptions().weight(weight)); +/// ``` +inline Tensor binary_cross_entropy( + const Tensor& input, + const Tensor& target, + const BinaryCrossEntropyFuncOptions& options = {}) { + return detail::binary_cross_entropy( + input, target, options.weight(), options.reduction()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor hinge_embedding_loss( + const Tensor& input, + const Tensor& target, + double margin, + HingeEmbeddingLossFuncOptions::reduction_t reduction) { + return torch::hinge_embedding_loss( + input, target, margin, enumtype::reduction_get_enum(reduction)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.hinge_embedding_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::HingeEmbeddingLossFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::hinge_embedding_loss(input, target, +/// F::HingeEmbeddingLossFuncOptions().margin(2)); +/// ``` +inline Tensor hinge_embedding_loss( + const Tensor& input, + const Tensor& target, + const HingeEmbeddingLossFuncOptions& options = {}) { + return detail::hinge_embedding_loss( + input, target, options.margin(), options.reduction()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor multi_margin_loss( + const Tensor& input, + const Tensor& target, + int64_t p, + double margin, + const Tensor& weight, + MultiMarginLossFuncOptions::reduction_t reduction) { + TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported"); + if (weight.defined()) { + TORCH_CHECK(weight.dim() == 1, "weight must be one-dimensional"); + } + + return torch::multi_margin_loss( + input, + target, + p, + margin, + weight, + enumtype::reduction_get_enum(reduction)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.multi_margin_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::MultiMarginLossFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::multi_margin_loss(input, target, +/// F::MultiMarginLossFuncOptions().margin(2).weight(weight)); +/// ``` +inline Tensor multi_margin_loss( + const Tensor& input, + const Tensor& target, + const MultiMarginLossFuncOptions& options = {}) { + return detail::multi_margin_loss( + input, + target, + options.p(), + options.margin(), + options.weight(), + options.reduction()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor cosine_embedding_loss( + const Tensor& input1, + const Tensor& input2, + const Tensor& target, + double margin, + CosineEmbeddingLossFuncOptions::reduction_t reduction) { + return torch::cosine_embedding_loss( + input1, input2, target, margin, enumtype::reduction_get_enum(reduction)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.cosine_embedding_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::CosineEmbeddingLossFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::cosine_embedding_loss(input1, input2, target, +/// F::CosineEmbeddingLossFuncOptions().margin(0.5)); +/// ``` +inline Tensor cosine_embedding_loss( + const Tensor& input1, + const Tensor& input2, + const Tensor& target, + const CosineEmbeddingLossFuncOptions& options = {}) { + return detail::cosine_embedding_loss( + input1, input2, target, options.margin(), options.reduction()); +} + +// ============================================================================ + +inline Tensor _smooth_l1_loss( + const Tensor& input, + const Tensor& target, + double beta = 1.) { + auto t = torch::abs(input - target); + return torch::where(t < beta, 0.5 * torch::pow(t, 2) / beta, t - 0.5 * beta); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor smooth_l1_loss( + const Tensor& input, + const Tensor& target, + SmoothL1LossFuncOptions::reduction_t reduction, + std::optional beta_opt = std::nullopt) { + if (target.sizes() != input.sizes()) { + TORCH_WARN( + "Using a target size (", + target.sizes(), + ") that is different to the input size (", + input.sizes(), + "). ", + "This will likely lead to incorrect results due to broadcasting. ", + "Please ensure they have the same size."); + } + double beta = beta_opt.value_or(1.0); + + std::vector expanded_tensors = + torch::broadcast_tensors({input, target}); + return torch::smooth_l1_loss( + expanded_tensors[0], + expanded_tensors[1], + enumtype::reduction_get_enum(reduction), + beta); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.smooth_l1_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::SmoothL1LossFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::smooth_l1_loss(input, target, F::SmoothL1LossFuncOptions(torch::kNone)); +/// ``` +inline Tensor smooth_l1_loss( + const Tensor& input, + const Tensor& target, + const SmoothL1LossFuncOptions& options = {}) { + return detail::smooth_l1_loss( + input, target, options.reduction(), options.beta()); +} + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.smooth_l1_loss +/// about the exact behavior of this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::smooth_l1_loss(input, target, /*options=*/torch::kNone, /*beta=*/0.5); +/// ``` +inline Tensor smooth_l1_loss( + const Tensor& input, + const Tensor& target, + const SmoothL1LossFuncOptions& options, + double beta) { + TORCH_CHECK( + !options.beta().has_value(), + "expected beta not to be provided in 'options', but got ", + options.beta()); + return detail::smooth_l1_loss(input, target, options.reduction(), beta); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor huber_loss( + const Tensor& input, + const Tensor& target, + HuberLossFuncOptions::reduction_t reduction, + double delta = 1.) { + if (target.sizes() != input.sizes()) { + TORCH_WARN( + "Using a target size (", + target.sizes(), + ") that is different to the input size (", + input.sizes(), + "). ", + "This will likely lead to incorrect results due to broadcasting. ", + "Please ensure they have the same size."); + } + + std::vector expanded_tensors = + torch::broadcast_tensors({input, target}); + return torch::huber_loss( + expanded_tensors[0], + expanded_tensors[1], + enumtype::reduction_get_enum(reduction), + delta); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.huber_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::HuberLossFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::huber_loss(input, target, +/// F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5)); +/// ``` +inline Tensor huber_loss( + const Tensor& input, + const Tensor& target, + const HuberLossFuncOptions& options = {}) { + return detail::huber_loss( + input, target, options.reduction(), options.delta()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor multilabel_margin_loss( + const Tensor& input, + const Tensor& target, + MultilabelMarginLossFuncOptions::reduction_t reduction) { + return torch::multilabel_margin_loss( + input, target, enumtype::reduction_get_enum(reduction)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.multilabel_margin_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::MultilabelMarginLossFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::multilabel_margin_loss(input, target, +/// F::MultilabelMarginLossFuncOptions(torch::kNone)); +/// ``` +inline Tensor multilabel_margin_loss( + const Tensor& input, + const Tensor& target, + const MultilabelMarginLossFuncOptions& options = {}) { + return detail::multilabel_margin_loss(input, target, options.reduction()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor soft_margin_loss( + const Tensor& input, + const Tensor& target, + SoftMarginLossFuncOptions::reduction_t reduction) { + return torch::soft_margin_loss( + input, target, enumtype::reduction_get_enum(reduction)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.soft_margin_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::SoftMarginLossFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::soft_margin_loss(input, target, +/// F::SoftMarginLossFuncOptions(torch::kNone)); +/// ``` +inline Tensor soft_margin_loss( + const Tensor& input, + const Tensor& target, + const SoftMarginLossFuncOptions& options = {}) { + return detail::soft_margin_loss(input, target, options.reduction()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor multilabel_soft_margin_loss( + const Tensor& input, + const Tensor& target, + const Tensor& weight, + MultilabelSoftMarginLossFuncOptions::reduction_t reduction) { + auto loss = + -(target * torch::log_sigmoid(input) + + (1 - target) * torch::log_sigmoid(-input)); + if (weight.defined()) { + loss = loss * weight; + } + + auto class_dim = input.dim() - 1; + auto C = input.size(class_dim); + loss = loss.sum(class_dim) / C; // only return N loss values + + Tensor ret; + + if (std::holds_alternative(reduction)) { + ret = loss; + } else if (std::holds_alternative(reduction)) { + ret = loss.mean(); + } else if (std::holds_alternative(reduction)) { + ret = loss.sum(); + } else { + ret = input; + TORCH_INTERNAL_ASSERT( + false, enumtype::get_enum_name(reduction), " is not valid"); + } + return ret; +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.multilabel_soft_margin_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::MultilabelSoftMarginLossFuncOptions` class to learn +/// what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::multilabel_soft_margin_loss(input, target, +/// F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone).weight(weight)); +/// ``` +inline Tensor multilabel_soft_margin_loss( + const Tensor& input, + const Tensor& target, + const MultilabelSoftMarginLossFuncOptions& options = {}) { + return detail::multilabel_soft_margin_loss( + input, target, options.weight(), options.reduction()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor triplet_margin_loss( + const Tensor& anchor, + const Tensor& positive, + const Tensor& negative, + double margin, + double p, + double eps, + bool swap, + TripletMarginLossFuncOptions::reduction_t reduction) { + return torch::triplet_margin_loss( + anchor, + positive, + negative, + margin, + p, + eps, + swap, + enumtype::reduction_get_enum(reduction)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.triplet_margin_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::TripletMarginLossFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::triplet_margin_loss(anchor, positive, negative, +/// F::TripletMarginLossFuncOptions().margin(1.0)); +/// ``` +inline Tensor triplet_margin_loss( + const Tensor& anchor, + const Tensor& positive, + const Tensor& negative, + const TripletMarginLossFuncOptions& options = {}) { + return detail::triplet_margin_loss( + anchor, + positive, + negative, + options.margin(), + options.p(), + options.eps(), + options.swap(), + options.reduction()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor triplet_margin_with_distance_loss( + const Tensor& anchor, + const Tensor& positive, + const Tensor& negative, + std::optional + distance_function, + double margin, + bool swap, + TripletMarginWithDistanceLossFuncOptions::reduction_t reduction) { + Tensor dist_pos, dist_neg; + if (distance_function.has_value()) { + auto distance_function_impl = distance_function.value(); + dist_pos = distance_function_impl(anchor, positive); + dist_neg = distance_function_impl(anchor, negative); + } else { + dist_pos = pairwise_distance(anchor, positive); + dist_neg = pairwise_distance(anchor, negative); + } + + if (swap) { + Tensor dist_swap; + if (distance_function.has_value()) { + dist_swap = distance_function.value()(positive, negative); + } else { + dist_swap = pairwise_distance(positive, negative); + } + dist_neg = torch::min(dist_neg, dist_swap); + } + + auto loss = torch::clamp_min(dist_pos - dist_neg + margin, 0); + + Tensor ret; + if (std::holds_alternative(reduction)) { + ret = loss; + } else if (std::holds_alternative(reduction)) { + ret = loss.mean(); + } else if (std::holds_alternative(reduction)) { + ret = loss.sum(); + } else { + ret = anchor; + TORCH_INTERNAL_ASSERT( + false, enumtype::get_enum_name(reduction), " is not valid"); + } + return ret; +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.triplet_margin_with_distance_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::TripletMarginWithDistanceLossFuncOptions` class to +/// learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::triplet_margin_with_distance_loss(anchor, positive, negative, +/// F::TripletMarginWithDistanceLossFuncOptions().margin(1.0)); +/// ``` +inline Tensor triplet_margin_with_distance_loss( + const Tensor& anchor, + const Tensor& positive, + const Tensor& negative, + const TripletMarginWithDistanceLossFuncOptions& options = {}) { + return detail::triplet_margin_with_distance_loss( + anchor, + positive, + negative, + options.distance_function(), + options.margin(), + options.swap(), + options.reduction()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor ctc_loss( + const Tensor& log_probs, + const Tensor& targets, + const Tensor& input_lengths, + const Tensor& target_lengths, + int64_t blank, + CTCLossFuncOptions::reduction_t reduction, + bool zero_infinity) { + return torch::ctc_loss( + log_probs, + targets, + input_lengths, + target_lengths, + blank, + enumtype::reduction_get_enum(reduction), + zero_infinity); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.ctc_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::CTCLossFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::ctc_loss(log_probs, targets, input_lengths, target_lengths, +/// F::CTCLossFuncOptions().reduction(torch::kNone)); +/// ``` +inline Tensor ctc_loss( + const Tensor& log_probs, + const Tensor& targets, + const Tensor& input_lengths, + const Tensor& target_lengths, + const CTCLossFuncOptions& options = {}) { + return detail::ctc_loss( + log_probs, + targets, + input_lengths, + target_lengths, + options.blank(), + options.reduction(), + options.zero_infinity()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor poisson_nll_loss( + const Tensor& input, + const Tensor& target, + bool log_input, + bool full, + double eps, + PoissonNLLLossFuncOptions::reduction_t reduction) { + return torch::poisson_nll_loss( + input, + target, + log_input, + full, + eps, + enumtype::reduction_get_enum(reduction)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.poisson_nll_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::PoissonNLLLossFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::poisson_nll_loss(input, target, +/// F::PoissonNLLLossFuncOptions().reduction(torch::kNone)); +/// ``` +inline Tensor poisson_nll_loss( + const Tensor& input, + const Tensor& target, + const PoissonNLLLossFuncOptions& options = {}) { + return detail::poisson_nll_loss( + input, + target, + options.log_input(), + options.full(), + options.eps(), + options.reduction()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor margin_ranking_loss( + const Tensor& input1, + const Tensor& input2, + const Tensor& target, + double margin, + MarginRankingLossFuncOptions::reduction_t reduction) { + TORCH_CHECK( + input1.dim() == input2.dim() && input1.dim() == target.dim(), + "margin_ranking_loss : All input tensors should have same dimension but got sizes: " + "input1: ", + input1.sizes(), + ", input2: ", + input2.sizes(), + ", target: ", + target.sizes()); + return torch::margin_ranking_loss( + input1, input2, target, margin, enumtype::reduction_get_enum(reduction)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.margin_ranking_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::MarginRankingLossFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::margin_ranking_loss(input1, input2, target, +/// F::MarginRankingLossFuncOptions().margin(0.5).reduction(torch::kSum)); +/// ``` +inline Tensor margin_ranking_loss( + const Tensor& input1, + const Tensor& input2, + const Tensor& target, + const MarginRankingLossFuncOptions& options = {}) { + return detail::margin_ranking_loss( + input1, input2, target, options.margin(), options.reduction()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor nll_loss( + const Tensor& input, + const Tensor& target, + const Tensor& weight, + int64_t ignore_index, + const NLLLossFuncOptions::reduction_t& reduction) { + if (input.dim() < 2) { + TORCH_CHECK(false, "Expected 2 or more dimensions (got ", input.dim(), ")"); + } + + if (input.sizes()[0] != target.sizes()[0]) { + TORCH_CHECK( + false, + "Expected input batch_size (", + input.sizes()[0], + ") to match target batch_size (", + target.sizes()[0], + ")."); + } + + return torch::nll_loss_nd( + input, + target, + weight, + enumtype::reduction_get_enum(reduction), + ignore_index); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.nll_loss +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::NLLLossFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::nll_loss(input, target, +/// F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean)); +/// ``` +inline Tensor nll_loss( + const Tensor& input, + const Tensor& target, + const NLLLossFuncOptions& options = {}) { + return detail::nll_loss( + input, + target, + options.weight(), + options.ignore_index(), + options.reduction()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor cross_entropy( + const Tensor& input, + const Tensor& target, + const Tensor& weight, + int64_t ignore_index, + CrossEntropyFuncOptions::reduction_t reduction, + double label_smoothing) { + return torch::cross_entropy_loss( + input, + target, + weight, + enumtype::reduction_get_enum(reduction), + ignore_index, + label_smoothing); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.cross_entropy +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::CrossEntropyFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::cross_entropy(input, target, +/// F::CrossEntropyFuncOptions().ignore_index(-100).reduction(torch::kMean)); +/// ``` +inline Tensor cross_entropy( + const Tensor& input, + const Tensor& target, + const CrossEntropyFuncOptions& options = {}) { + return detail::cross_entropy( + input, + target, + options.weight(), + options.ignore_index(), + options.reduction(), + options.label_smoothing()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor binary_cross_entropy_with_logits( + const Tensor& input, + const Tensor& target, + const Tensor& weight, + BinaryCrossEntropyWithLogitsFuncOptions::reduction_t reduction, + const Tensor& pos_weight) { + TORCH_CHECK( + target.sizes() == input.sizes(), + "Target size (", + target.sizes(), + ") must be the same as input size (", + input.sizes(), + ")"); + + return torch::binary_cross_entropy_with_logits( + input, + target, + weight, + pos_weight, + enumtype::reduction_get_enum(reduction)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.binary_cross_entropy_with_logits +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::BinaryCrossEntropyWithLogitsFuncOptions` class to +/// learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::binary_cross_entropy_with_logits(input, target, +/// F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight).reduction(torch::kSum)); +/// ``` +inline Tensor binary_cross_entropy_with_logits( + const Tensor& input, + const Tensor& target, + const BinaryCrossEntropyWithLogitsFuncOptions& options = {}) { + return detail::binary_cross_entropy_with_logits( + input, + target, + options.weight(), + options.reduction(), + options.pos_weight()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/normalization.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/normalization.h new file mode 100644 index 0000000000000000000000000000000000000000..3df0189890864f4b18eb7ae8321625733a7664e3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/normalization.h @@ -0,0 +1,207 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn::functional { + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor normalize( + const Tensor& input, + double p, + int64_t dim, + double eps, + std::optional out) { + if (out == std::nullopt) { + auto denom = input.norm(p, dim, true).clamp_min(eps).expand_as(input); + return input / denom; + } else { + auto denom = input.norm(p, dim, true).clamp_min(eps).expand_as(input); + return torch::div_out(*out, input, denom); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.normalize +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::NormalizeFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1)); +/// ``` +inline Tensor normalize( + const Tensor& input, + NormalizeFuncOptions options = {}) { + return detail::normalize( + input, options.p(), options.dim(), options.eps(), options.out()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor layer_norm( + const Tensor& input, + const std::vector& normalized_shape, + const Tensor& weight, + const Tensor& bias, + double eps) { + return torch::layer_norm(input, normalized_shape, weight, bias, eps); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.layer_norm +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::LayerNormFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::layer_norm(input, F::LayerNormFuncOptions({2, 2}).eps(2e-5)); +/// ``` +inline Tensor layer_norm( + const Tensor& input, + const LayerNormFuncOptions& options) { + return detail::layer_norm( + input, + options.normalized_shape(), + options.weight(), + options.bias(), + options.eps()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor local_response_norm( + const Tensor& input, + int64_t size, + double alpha, + double beta, + double k) { + auto dim = input.dim(); + TORCH_CHECK( + dim >= 3, + "Expected 3D or higher dimensionality input (got ", + dim, + " dimensions)"); + auto div = input.mul(input).unsqueeze(1); + if (dim == 3) { + div = detail::pad( + div, + /*pad=*/{0, 0, size / 2, (size - 1) / 2}, + /*mode=*/torch::kConstant, + /*value=*/0); + div = detail::avg_pool2d( + div, + /*kernel_size=*/{size, 1}, + /*stride=*/1, + /*padding=*/0, + /*ceil_mode=*/false, + /*count_include_pad=*/true, + /*divisor_override=*/std::nullopt) + .squeeze(1); + } else { + auto sizes = input.sizes(); + div = div.view({sizes[0], 1, sizes[1], sizes[2], -1}); + div = detail::pad( + div, + /*pad=*/{0, 0, 0, 0, size / 2, (size - 1) / 2}, + /*mode=*/torch::kConstant, + /*value=*/0); + div = detail::avg_pool3d( + div, + /*kernel_size=*/{size, 1, 1}, + /*stride=*/1, + /*padding=*/0, + /*ceil_mode=*/false, + /*count_include_pad=*/true, + /*divisor_override=*/std::nullopt) + .squeeze(1); + div = div.view(sizes); + } + div = div.mul(alpha).add(k).pow(beta); + return input / div; +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.local_response_norm +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::LocalResponseNormFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::local_response_norm(x, F::LocalResponseNormFuncOptions(2)); +/// ``` +inline Tensor local_response_norm( + const Tensor& input, + const LocalResponseNormFuncOptions& options) { + return detail::local_response_norm( + input, options.size(), options.alpha(), options.beta(), options.k()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor group_norm( + const Tensor& input, + int64_t num_groups, + const Tensor& weight, + const Tensor& bias, + double eps) { + return torch::group_norm( + input, + num_groups, + weight, + bias, + eps, + at::globalContext().userEnabledCuDNN()); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.group_norm +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::GroupNormFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::group_norm(input, F::GroupNormFuncOptions(2).eps(2e-5)); +/// ``` +inline Tensor group_norm( + const Tensor& input, + const GroupNormFuncOptions& options) { + return detail::group_norm( + input, + options.num_groups(), + options.weight(), + options.bias(), + options.eps()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/padding.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/padding.h new file mode 100644 index 0000000000000000000000000000000000000000..5ef8b6ff34492a032983aaffae9251241a94160a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/padding.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +namespace torch::nn::functional { + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor pad( + const Tensor& input, + IntArrayRef pad, + PadFuncOptions::mode_t mode, + double value) { + const auto mode_enum = [&] { + if (std::holds_alternative(mode)) { + return at::padding_mode::constant; + } else if (std::holds_alternative(mode)) { + return at::padding_mode::reflect; + } else if (std::holds_alternative(mode)) { + return at::padding_mode::replicate; + } else if (std::holds_alternative(mode)) { + return at::padding_mode::circular; + } + TORCH_CHECK(false, "Unrecognised padding mode"); + }(); + + std::optional fill_value; + if (value != 0.0) { + fill_value = value; + } + return at::_pad_enum(input, pad, static_cast(mode_enum), fill_value); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.pad +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::PadFuncOptions` class to +/// learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::pad(input, F::PadFuncOptions({1, 2, 2, 1, 1, +/// 2}).mode(torch::kReplicate)); +/// ``` +inline Tensor pad(const Tensor& input, const PadFuncOptions& options) { + return detail::pad(input, options.pad(), options.mode(), options.value()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h new file mode 100644 index 0000000000000000000000000000000000000000..4d005f3568969fb5e667b823a7470d7d20e08cd9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h @@ -0,0 +1,43 @@ +#pragma once + +#include + +namespace torch::nn::functional { + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor pixel_shuffle(const Tensor& input, int64_t upscale_factor) { + return torch::pixel_shuffle(input, upscale_factor); +} + +inline Tensor pixel_unshuffle(const Tensor& input, int64_t downscale_factor) { + return torch::pixel_unshuffle(input, downscale_factor); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.pixel_shuffle +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::PixelShuffleFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::pixel_shuffle(x, F::PixelShuffleFuncOptions(2)); +/// ``` +inline Tensor pixel_shuffle( + const Tensor& input, + const PixelShuffleFuncOptions& options) { + return detail::pixel_shuffle(input, options.upscale_factor()); +} + +inline Tensor pixel_unshuffle( + const Tensor& input, + const PixelUnshuffleFuncOptions& options) { + return detail::pixel_unshuffle(input, options.downscale_factor()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/pooling.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..72aaca76f6f4dbab6525d593b000e821f4d2d627 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/pooling.h @@ -0,0 +1,1149 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn::functional { + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor avg_pool1d( + const Tensor& input, + ExpandingArray<1> kernel_size, + ExpandingArray<1> stride, + ExpandingArray<1> padding, + bool ceil_mode, + bool count_include_pad) { + return torch::avg_pool1d( + input, kernel_size, stride, padding, ceil_mode, count_include_pad); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.avg_pool1d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::AvgPool1dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::avg_pool1d(x, F::AvgPool1dFuncOptions(3).stride(2)); +/// ``` +inline Tensor avg_pool1d( + const Tensor& input, + const AvgPool1dFuncOptions& options) { + return avg_pool1d( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.ceil_mode(), + options.count_include_pad()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor avg_pool2d( + const Tensor& input, + ExpandingArray<2> kernel_size, + ExpandingArray<2> stride, + ExpandingArray<2> padding, + bool ceil_mode, + bool count_include_pad, + std::optional divisor_override) { + return torch::avg_pool2d( + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.avg_pool2d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::AvgPool2dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::avg_pool2d(x, F::AvgPool2dFuncOptions(3).stride(2)); +/// ``` +inline Tensor avg_pool2d( + const Tensor& input, + const AvgPool2dFuncOptions& options) { + return detail::avg_pool2d( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.ceil_mode(), + options.count_include_pad(), + options.divisor_override()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor avg_pool3d( + const Tensor& input, + ExpandingArray<3> kernel_size, + ExpandingArray<3> stride, + ExpandingArray<3> padding, + bool ceil_mode, + bool count_include_pad, + std::optional divisor_override) { + return torch::avg_pool3d( + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.avg_pool3d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::AvgPool3dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::avg_pool3d(x, F::AvgPool3dFuncOptions(3).stride(2)); +/// ``` +inline Tensor avg_pool3d( + const Tensor& input, + const AvgPool3dFuncOptions& options) { + return detail::avg_pool3d( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.ceil_mode(), + options.count_include_pad(), + options.divisor_override()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor max_pool1d( + const Tensor& input, + ExpandingArray<1> kernel_size, + ExpandingArray<1> stride, + ExpandingArray<1> padding, + ExpandingArray<1> dilation, + bool ceil_mode) { + return torch::max_pool1d( + input, kernel_size, stride, padding, dilation, ceil_mode); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.max_pool1d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::MaxPool1dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_pool1d(x, F::MaxPool1dFuncOptions(3).stride(2)); +/// ``` +inline Tensor max_pool1d( + const Tensor& input, + const MaxPool1dFuncOptions& options) { + return detail::max_pool1d( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline std::tuple max_pool1d_with_indices( + const Tensor& input, + ExpandingArray<1> kernel_size, + ExpandingArray<1> stride, + ExpandingArray<1> padding, + ExpandingArray<1> dilation, + bool ceil_mode) { + return torch::max_pool1d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See the documentation for `torch::nn::functional::MaxPool1dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_pool1d_with_indices(x, F::MaxPool1dFuncOptions(3).stride(2)); +/// ``` +inline std::tuple max_pool1d_with_indices( + const Tensor& input, + const MaxPool1dFuncOptions& options) { + return detail::max_pool1d_with_indices( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor max_pool2d( + const Tensor& input, + ExpandingArray<2> kernel_size, + ExpandingArray<2> stride, + ExpandingArray<2> padding, + ExpandingArray<2> dilation, + bool ceil_mode) { + return torch::max_pool2d( + input, kernel_size, stride, padding, dilation, ceil_mode); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.max_pool2d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::MaxPool2dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_pool2d(x, F::MaxPool2dFuncOptions(3).stride(2)); +/// ``` +inline Tensor max_pool2d( + const Tensor& input, + const MaxPool2dFuncOptions& options) { + return detail::max_pool2d( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline std::tuple max_pool2d_with_indices( + const Tensor& input, + ExpandingArray<2> kernel_size, + ExpandingArray<2> stride, + ExpandingArray<2> padding, + ExpandingArray<2> dilation, + bool ceil_mode) { + return torch::max_pool2d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See the documentation for `torch::nn::functional::MaxPool2dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_pool2d_with_indices(x, F::MaxPool2dFuncOptions(3).stride(2)); +/// ``` +inline std::tuple max_pool2d_with_indices( + const Tensor& input, + const MaxPool2dFuncOptions& options) { + return detail::max_pool2d_with_indices( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor max_pool3d( + const Tensor& input, + ExpandingArray<3> kernel_size, + ExpandingArray<3> stride, + ExpandingArray<3> padding, + ExpandingArray<3> dilation, + bool ceil_mode) { + return torch::max_pool3d( + input, kernel_size, stride, padding, dilation, ceil_mode); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.max_pool3d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::MaxPool3dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_pool3d(x, F::MaxPool3dFuncOptions(3).stride(2)); +/// ``` +inline Tensor max_pool3d( + const Tensor& input, + const MaxPool3dFuncOptions& options) { + return detail::max_pool3d( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline std::tuple max_pool3d_with_indices( + const Tensor& input, + ExpandingArray<3> kernel_size, + ExpandingArray<3> stride, + ExpandingArray<3> padding, + ExpandingArray<3> dilation, + bool ceil_mode) { + return torch::max_pool3d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See the documentation for `torch::nn::functional::MaxPool3dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_pool3d_with_indices(x, F::MaxPool3dFuncOptions(3).stride(2)); +/// ``` +inline std::tuple max_pool3d_with_indices( + const Tensor& input, + const MaxPool3dFuncOptions& options) { + return detail::max_pool3d_with_indices( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline std::tuple adaptive_max_pool1d_with_indices( + const Tensor& input, + ExpandingArray<1> output_size) { + return torch::adaptive_max_pool1d(input, output_size); +} +} // namespace detail + +/// See the documentation for +/// `torch::nn::functional::AdaptiveMaxPool1dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_max_pool1d_with_indices(x, F::AdaptiveMaxPool1dFuncOptions(3)); +/// ``` +inline std::tuple adaptive_max_pool1d_with_indices( + const Tensor& input, + const AdaptiveMaxPool1dFuncOptions& options) { + return detail::adaptive_max_pool1d_with_indices(input, options.output_size()); +} + +namespace detail { +inline Tensor adaptive_max_pool1d( + const Tensor& input, + ExpandingArray<1> output_size) { + return std::get<0>(adaptive_max_pool1d_with_indices(input, output_size)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.adaptive_max_pool1d +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::AdaptiveMaxPool1dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_max_pool1d(x, F::AdaptiveMaxPool1dFuncOptions(3)); +/// ``` +inline Tensor adaptive_max_pool1d( + const Tensor& input, + const AdaptiveMaxPool1dFuncOptions& options) { + return detail::adaptive_max_pool1d(input, options.output_size()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline std::tuple adaptive_max_pool2d_with_indices( + const Tensor& input, + ExpandingArrayWithOptionalElem<2> output_size) { + auto output_size_ = + torch::nn::modules::utils::_list_with_default(output_size, input.sizes()); + return torch::adaptive_max_pool2d(input, output_size_); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See the documentation for +/// `torch::nn::functional::AdaptiveMaxPool2dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_max_pool2d_with_indices(x, F::AdaptiveMaxPool2dFuncOptions(3)); +/// ``` +inline std::tuple adaptive_max_pool2d_with_indices( + const Tensor& input, + const AdaptiveMaxPool2dFuncOptions& options) { + return detail::adaptive_max_pool2d_with_indices(input, options.output_size()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor adaptive_max_pool2d( + const Tensor& input, + ExpandingArrayWithOptionalElem<2> output_size) { + return std::get<0>(adaptive_max_pool2d_with_indices(input, output_size)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.adaptive_max_pool2d +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::AdaptiveMaxPool2dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_max_pool2d(x, F::AdaptiveMaxPool2dFuncOptions(3)); +/// ``` +inline Tensor adaptive_max_pool2d( + const Tensor& input, + const AdaptiveMaxPool2dFuncOptions& options) { + return detail::adaptive_max_pool2d(input, options.output_size()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline std::tuple adaptive_max_pool3d_with_indices( + const Tensor& input, + ExpandingArrayWithOptionalElem<3> output_size) { + auto output_size_ = + torch::nn::modules::utils::_list_with_default(output_size, input.sizes()); + return torch::adaptive_max_pool3d(input, output_size_); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See the documentation for +/// `torch::nn::functional::AdaptiveMaxPool3dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_max_pool3d_with_indices(x, F::AdaptiveMaxPool3dFuncOptions(3)); +/// ``` +inline std::tuple adaptive_max_pool3d_with_indices( + const Tensor& input, + const AdaptiveMaxPool3dFuncOptions& options) { + return detail::adaptive_max_pool3d_with_indices(input, options.output_size()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor adaptive_max_pool3d( + const Tensor& input, + ExpandingArrayWithOptionalElem<3> output_size) { + return std::get<0>(adaptive_max_pool3d_with_indices(input, output_size)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.adaptive_max_pool3d +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::AdaptiveMaxPool3dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_max_pool3d(x, F::AdaptiveMaxPool3dFuncOptions(3)); +/// ``` +inline Tensor adaptive_max_pool3d( + const Tensor& input, + const AdaptiveMaxPool3dFuncOptions& options) { + return detail::adaptive_max_pool3d(input, options.output_size()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor adaptive_avg_pool1d( + const Tensor& input, + ExpandingArray<1> output_size) { + return torch::adaptive_avg_pool1d(input, output_size); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.adaptive_avg_pool1d +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::AdaptiveAvgPool1dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_avg_pool1d(x, F::AdaptiveAvgPool1dFuncOptions(3)); +/// ``` +inline Tensor adaptive_avg_pool1d( + const Tensor& input, + const AdaptiveAvgPool1dFuncOptions& options) { + return detail::adaptive_avg_pool1d(input, options.output_size()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor adaptive_avg_pool2d( + const Tensor& input, + ExpandingArrayWithOptionalElem<2> output_size) { + auto output_size_ = + torch::nn::modules::utils::_list_with_default(output_size, input.sizes()); + return torch::adaptive_avg_pool2d(input, output_size_); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.adaptive_avg_pool2d +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::AdaptiveAvgPool2dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_avg_pool2d(x, F::AdaptiveAvgPool2dFuncOptions(3)); +/// ``` +inline Tensor adaptive_avg_pool2d( + const Tensor& input, + const AdaptiveAvgPool2dFuncOptions& options) { + return detail::adaptive_avg_pool2d(input, options.output_size()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor adaptive_avg_pool3d( + const Tensor& input, + ExpandingArrayWithOptionalElem<3> output_size) { + auto output_size_ = + torch::nn::modules::utils::_list_with_default(output_size, input.sizes()); + return torch::adaptive_avg_pool3d(input, output_size_); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.adaptive_avg_pool3d +/// about the exact behavior of this functional. +/// +/// See the documentation for +/// `torch::nn::functional::AdaptiveAvgPool3dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_avg_pool3d(x, F::AdaptiveAvgPool3dFuncOptions(3)); +/// ``` +inline Tensor adaptive_avg_pool3d( + const Tensor& input, + const AdaptiveAvgPool3dFuncOptions& options) { + return detail::adaptive_avg_pool3d(input, options.output_size()); +} + +// ============================================================================ + +inline std::vector _unpool_output_size( + const Tensor& input, + const IntArrayRef& kernel_size, + const IntArrayRef& stride, + const IntArrayRef& padding, + const std::optional>& output_size) { + auto input_size = input.sizes(); + std::vector default_size; + for (const auto d : c10::irange(kernel_size.size())) { + default_size.push_back( + (input_size[input_size.size() - kernel_size.size() + d] - 1) * + stride[d] + + kernel_size[d] - 2 * padding[d]); + } + if (!output_size) { + return default_size; + } else { + std::vector output_size_; + if (output_size->size() == kernel_size.size() + 2) { + output_size_ = IntArrayRef(*output_size).slice(2).vec(); + } + if (output_size_.size() != kernel_size.size()) { + TORCH_CHECK( + false, + "output_size should be a sequence containing ", + kernel_size.size(), + " or ", + kernel_size.size() + 2, + " elements, but it has a length of '", + output_size_.size(), + "'"); + } + for (const auto d : c10::irange(kernel_size.size())) { + const auto min_size = default_size[d] - stride[d]; + const auto max_size = default_size[d] + stride[d]; + if (!(min_size <= output_size_[d] && output_size_[d] <= max_size)) { + TORCH_CHECK( + false, + "invalid output_size ", + output_size_, + " (dim ", + d, + " must be between ", + min_size, + " and ", + max_size, + ")"); + } + } + return output_size_; + } +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor max_unpool1d( + const Tensor& input, + const Tensor& indices, + ExpandingArray<1> kernel_size, + ExpandingArray<1> stride, + ExpandingArray<1> padding, + const std::optional>& output_size) { + auto output_size_ = + _unpool_output_size(input, kernel_size, stride, padding, output_size); + output_size_.push_back(1); + return torch::max_unpool2d( + input.unsqueeze(-1), indices.unsqueeze(-1), output_size_) + .squeeze(-1); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.max_unpool1d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::MaxUnpool1dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_unpool1d(x, indices, +/// F::MaxUnpool1dFuncOptions(3).stride(2).padding(1)); +/// ``` +inline Tensor max_unpool1d( + const Tensor& input, + const Tensor& indices, + const MaxUnpool1dFuncOptions& options) { + return detail::max_unpool1d( + input, + indices, + options.kernel_size(), + options.stride(), + options.padding(), + options.output_size()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor max_unpool2d( + const Tensor& input, + const Tensor& indices, + ExpandingArray<2> kernel_size, + ExpandingArray<2> stride, + ExpandingArray<2> padding, + const std::optional>& output_size) { + auto output_size_ = + _unpool_output_size(input, kernel_size, stride, padding, output_size); + + return torch::max_unpool2d(input, indices, output_size_); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.max_unpool2d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::MaxUnpool2dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_unpool2d(x, indices, +/// F::MaxUnpool2dFuncOptions(3).stride(2).padding(1)); +/// ``` +inline Tensor max_unpool2d( + const Tensor& input, + const Tensor& indices, + const MaxUnpool2dFuncOptions& options) { + return detail::max_unpool2d( + input, + indices, + options.kernel_size(), + options.stride(), + options.padding(), + options.output_size()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor max_unpool3d( + const Tensor& input, + const Tensor& indices, + ExpandingArray<3> kernel_size, + ExpandingArray<3> stride, + ExpandingArray<3> padding, + const std::optional>& output_size) { + auto output_size_ = + _unpool_output_size(input, kernel_size, stride, padding, output_size); + + return torch::max_unpool3d(input, indices, output_size_, stride, padding); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.max_unpool3d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::MaxUnpool3dFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_unpool3d(x, indices, F::MaxUnpool3dFuncOptions(3)); +/// ``` +inline Tensor max_unpool3d( + const Tensor& input, + const Tensor& indices, + const MaxUnpool3dFuncOptions& options) { + return detail::max_unpool3d( + input, + indices, + options.kernel_size(), + options.stride(), + options.padding(), + options.output_size()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline std::tuple fractional_max_pool2d_with_indices( + const Tensor& input, + const ExpandingArray<2>& kernel_size, + const std::optional>& output_size, + const std::optional>& output_ratio, + const Tensor& _random_samples) { + if (output_size == std::nullopt && output_ratio == std::nullopt) { + TORCH_CHECK( + false, + "fractional_max_pool2d requires specifying either ", + "an output_size or an output_ratio"); + } + std::optional> output_size_ = output_size; + if (output_size_ == std::nullopt) { + TORCH_INTERNAL_ASSERT(output_ratio != std::nullopt); + output_size_ = { + (int64_t)(static_cast(input.size(-2)) * + (*output_ratio.value())[0]), + (int64_t)(static_cast(input.size(-1)) * + (*output_ratio.value())[1])}; + } + + Tensor _random_samples_ = _random_samples; + if (!_random_samples_.defined()) { + auto n_batch = input.dim() == 3 ? 1 : input.size(0); + _random_samples_ = torch::rand( + {n_batch, input.size(-3), 2}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); + } + return torch::fractional_max_pool2d( + input, kernel_size, *output_size_, _random_samples_); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See the documentation for +/// `torch::nn::functional::FractionalMaxPool2dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::fractional_max_pool2d_with_indices(x, +/// F::FractionalMaxPool2dFuncOptions(3).output_size(2)); +/// ``` +inline std::tuple fractional_max_pool2d_with_indices( + const Tensor& input, + const FractionalMaxPool2dFuncOptions& options) { + return detail::fractional_max_pool2d_with_indices( + input, + options.kernel_size(), + options.output_size(), + options.output_ratio(), + options._random_samples()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor fractional_max_pool2d( + const Tensor& input, + ExpandingArray<2> kernel_size, + std::optional> output_size, + std::optional> output_ratio, + const Tensor& _random_samples) { + return std::get<0>(fractional_max_pool2d_with_indices( + input, kernel_size, output_size, output_ratio, _random_samples)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See the documentation for +/// `torch::nn::functional::FractionalMaxPool2dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::fractional_max_pool2d(x, +/// F::FractionalMaxPool2dFuncOptions(3).output_size(2)); +/// ``` +inline Tensor fractional_max_pool2d( + const Tensor& input, + const FractionalMaxPool2dFuncOptions& options) { + return detail::fractional_max_pool2d( + input, + options.kernel_size(), + options.output_size(), + options.output_ratio(), + options._random_samples()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline std::tuple fractional_max_pool3d_with_indices( + const Tensor& input, + const ExpandingArray<3>& kernel_size, + const std::optional>& output_size, + const std::optional>& output_ratio, + const Tensor& _random_samples) { + if (output_size == std::nullopt && output_ratio == std::nullopt) { + TORCH_CHECK( + false, + "fractional_max_pool3d requires specifying either ", + "an output_size or an output_ratio"); + } + + std::optional> output_size_ = output_size; + if (output_size_ == std::nullopt) { + TORCH_INTERNAL_ASSERT(output_ratio != std::nullopt); + output_size_ = { + (int64_t)(static_cast(input.size(-3)) * + (*output_ratio.value())[0]), + (int64_t)(static_cast(input.size(-2)) * + (*output_ratio.value())[1]), + (int64_t)(static_cast(input.size(-1)) * + (*output_ratio.value())[2])}; + } + + Tensor _random_samples_ = _random_samples; + if (!_random_samples_.defined()) { + auto n_batch = input.dim() == 4 ? 1 : input.size(0); + _random_samples_ = torch::rand( + {n_batch, input.size(-4), 3}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); + } + return torch::fractional_max_pool3d( + input, kernel_size, *output_size_, _random_samples_); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See the documentation for +/// `torch::nn::functional::FractionalMaxPool3dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::fractional_max_pool3d_with_indices(x, +/// F::FractionalMaxPool3dFuncOptions(3).output_size(2)); +/// ``` +inline std::tuple fractional_max_pool3d_with_indices( + const Tensor& input, + const FractionalMaxPool3dFuncOptions& options) { + return detail::fractional_max_pool3d_with_indices( + input, + options.kernel_size(), + options.output_size(), + options.output_ratio(), + options._random_samples()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor fractional_max_pool3d( + const Tensor& input, + ExpandingArray<3> kernel_size, + std::optional> output_size, + std::optional> output_ratio, + const Tensor& _random_samples) { + return std::get<0>(fractional_max_pool3d_with_indices( + input, kernel_size, output_size, output_ratio, _random_samples)); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See the documentation for +/// `torch::nn::functional::FractionalMaxPool3dFuncOptions` class to learn what +/// optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::fractional_max_pool3d(x, +/// F::FractionalMaxPool3dFuncOptions(3).output_size(2)); +/// ``` +inline Tensor fractional_max_pool3d( + const Tensor& input, + const FractionalMaxPool3dFuncOptions& options) { + return detail::fractional_max_pool3d( + input, + options.kernel_size(), + options.output_size(), + options.output_ratio(), + options._random_samples()); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor lp_pool1d( + const Tensor& input, + double norm_type, + ExpandingArray<1> kernel_size, + ExpandingArray<1> stride, + bool ceil_mode) { + Tensor out = detail::avg_pool1d( + input.pow(norm_type), + kernel_size, + stride, + /*padding=*/0, + ceil_mode, + /*count_include_pad=*/true); + + return (torch::sign(out) * relu(torch::abs(out))) + .mul((*kernel_size)[0]) + .pow(1. / norm_type); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.lp_pool1d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::LPPool1dFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::lp_pool1d(x, F::LPPool1dFuncOptions(2, 3).stride(2)); +/// ``` +inline Tensor lp_pool1d( + const Tensor& input, + const LPPool1dFuncOptions& options) { + return detail::lp_pool1d( + input, + options.norm_type(), + options.kernel_size(), + options.stride(), + options.ceil_mode()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor lp_pool2d( + const Tensor& input, + double norm_type, + ExpandingArray<2> kernel_size, + ExpandingArray<2> stride, + bool ceil_mode) { + auto kw = (*kernel_size)[0]; + auto kh = (*kernel_size)[1]; + Tensor out = detail::avg_pool2d( + input.pow(norm_type), + kernel_size, + stride, + /*padding=*/0, + ceil_mode, + /*count_include_pad=*/true, + /*divisor_override=*/std::nullopt); + + return (torch::sign(out) * relu(torch::abs(out))) + .mul(kw * kh) + .pow(1. / norm_type); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.lp_pool2d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::LPPool2dFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::lp_pool2d(x, F::LPPool2dFuncOptions(2, {2, 3}).stride(2)); +/// ``` +inline Tensor lp_pool2d( + const Tensor& input, + const LPPool2dFuncOptions& options) { + return detail::lp_pool2d( + input, + options.norm_type(), + options.kernel_size(), + options.stride(), + options.ceil_mode()); +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor lp_pool3d( + const Tensor& input, + double norm_type, + ExpandingArray<3> kernel_size, + ExpandingArray<3> stride, + bool ceil_mode) { + auto kd = (*kernel_size)[0]; + auto kw = (*kernel_size)[1]; + auto kh = (*kernel_size)[2]; + Tensor out = detail::avg_pool3d( + input.pow(norm_type), + kernel_size, + stride, + /*padding=*/0, + ceil_mode, + /*count_include_pad=*/true, + /*divisor_override=*/std::nullopt); + + return (torch::sign(out) * relu(torch::abs(out))) + .mul(kd * kw * kh) + .pow(1. / norm_type); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.lp_pool3d +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::LPPool3dFuncOptions` class +/// to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::lp_pool3d(x, F::LPPool3dFuncOptions(3, {3, 3, 5}).stride(3)); +/// ``` +inline Tensor lp_pool3d( + const Tensor& input, + const LPPool3dFuncOptions& options) { + return detail::lp_pool3d( + input, + options.norm_type(), + options.kernel_size(), + options.stride(), + options.ceil_mode()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/upsampling.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/upsampling.h new file mode 100644 index 0000000000000000000000000000000000000000..ace73152d88ca59cdcd180a4bfa1225588fd507b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/upsampling.h @@ -0,0 +1,286 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::nn::functional { + +inline std::vector _interp_output_size( + int64_t dim, + std::tuple< + Tensor, + std::optional>, + std::optional>, + std::optional> closed_over_args) { + auto [input, size, scale_factor, recompute_scale_factor] = + std::move(closed_over_args); + if (size == std::nullopt && scale_factor == std::nullopt) { + TORCH_CHECK(false, "either size or scale_factor should be defined"); + } + if (size != std::nullopt && scale_factor != std::nullopt) { + TORCH_CHECK(false, "only one of size or scale_factor should be defined"); + } + if (scale_factor != std::nullopt) { + if (static_cast(scale_factor.value().size()) != dim) { + TORCH_CHECK( + false, + "scale_factor shape must match input shape. ", + "Input is ", + dim, + "D, scale_factor size is ", + torch::ArrayRef(*scale_factor)); + } + } + if (size != std::nullopt) { + return *size; + } + + TORCH_INTERNAL_ASSERT(scale_factor != std::nullopt); + auto scale_factors = *scale_factor; + + if (recompute_scale_factor == std::nullopt) { + // only warn when the scales have floating values since + // the result for ints is the same with/without recompute_scale_factor + bool is_float_scale_factor = false; + for (double scale : scale_factors) { + is_float_scale_factor = floor(scale) != scale; + if (is_float_scale_factor) { + break; + } + } + if (is_float_scale_factor) { + TORCH_WARN( + "The default behavior for interpolate/upsample with float scale_factor changed " + "in 1.6.0 to align with other frameworks/libraries, and uses scale_factor directly, " + "instead of relying on the computed output size. " + "If you wish to keep the old behavior, please set recompute_scale_factor=True. " + "See the documentation of nn.Upsample for details. "); + } + } + + std::vector ret; + for (const auto i : c10::irange(dim)) { + ret.emplace_back(static_cast( + floor(static_cast(input.size(i + 2)) * scale_factors[i]))); + } + return ret; +} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor interpolate( + const Tensor& input, + const std::optional>& size, + const std::optional>& scale_factor, + InterpolateFuncOptions::mode_t mode, + std::optional align_corners, + std::optional recompute_scale_factor, + bool antialias) { + if (std::holds_alternative(mode) || + std::get_if(&mode)) { + if (align_corners != std::nullopt) { + TORCH_CHECK( + false, + "align_corners option can only be set with the " + "interpolating modes: linear | bilinear | bicubic | trilinear"); + } + } else { + if (align_corners == std::nullopt) { + TORCH_WARN( + "Default upsampling behavior when mode=", + enumtype::get_enum_name(mode), + " is changed " + "to align_corners=False since 0.4.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of nn.Upsample for details."); + align_corners = false; + } + } + + TORCH_CHECK( + input.dim() >= 3 && input.dim() <= 5, + "Input Error: Only 3D, 4D and 5D input Tensors supported " + "(got ", + input.dim(), + "D) for the modes: nearest | linear | bilinear | bicubic | trilinear " + "(got ", + enumtype::get_enum_name(mode), + ")"); + + auto scale_factor_len = input.dim() - 2; + std::vector> scale_factor_list( + scale_factor_len, std::nullopt); + if (scale_factor != std::nullopt && !recompute_scale_factor.value_or(false)) { + auto _scale_factor_repeated = *scale_factor; + scale_factor_list = {}; + for (const auto& elem : _scale_factor_repeated) { + scale_factor_list.emplace_back(elem); + } + } + + if (antialias && + !(input.dim() == 4 && + (std::get_if(&mode) || + std::get_if(&mode)))) { + TORCH_CHECK( + false, + "Anti-alias option is only supported for bilinear and bicubic modes"); + } + + auto closed_over_args = + std::make_tuple(input, size, scale_factor, recompute_scale_factor); + if (input.dim() == 3 && std::get_if(&mode)) { + return torch::upsample_nearest1d( + input, + _interp_output_size(1, std::move(closed_over_args)), + scale_factor_list.at(0)); + } else if (input.dim() == 4 && std::get_if(&mode)) { + return torch::upsample_nearest2d( + input, + _interp_output_size(2, std::move(closed_over_args)), + scale_factor_list.at(0), + scale_factor_list.at(1)); + } else if (input.dim() == 5 && std::get_if(&mode)) { + return torch::upsample_nearest3d( + input, + _interp_output_size(3, std::move(closed_over_args)), + scale_factor_list.at(0), + scale_factor_list.at(1), + scale_factor_list.at(2)); + } else if (input.dim() == 3 && std::get_if(&mode)) { + return torch::_upsample_nearest_exact1d( + input, + _interp_output_size(1, std::move(closed_over_args)), + scale_factor_list.at(0)); + } else if (input.dim() == 4 && std::get_if(&mode)) { + return torch::_upsample_nearest_exact2d( + input, + _interp_output_size(2, std::move(closed_over_args)), + scale_factor_list.at(0), + scale_factor_list.at(1)); + } else if (input.dim() == 5 && std::get_if(&mode)) { + return torch::_upsample_nearest_exact3d( + input, + _interp_output_size(3, std::move(closed_over_args)), + scale_factor_list.at(0), + scale_factor_list.at(1), + scale_factor_list.at(2)); + } else if (input.dim() == 3 && std::get_if(&mode)) { + return detail::adaptive_avg_pool1d( + input, _interp_output_size(1, std::move(closed_over_args))); + } else if (input.dim() == 4 && std::get_if(&mode)) { + return detail::adaptive_avg_pool2d( + input, _interp_output_size(2, std::move(closed_over_args))); + } else if (input.dim() == 5 && std::get_if(&mode)) { + return detail::adaptive_avg_pool3d( + input, _interp_output_size(3, std::move(closed_over_args))); + } else if (input.dim() == 3 && std::get_if(&mode)) { + TORCH_CHECK( + align_corners != std::nullopt, "align_corners should be specified."); + return torch::upsample_linear1d( + input, + _interp_output_size(1, std::move(closed_over_args)), + *align_corners, + scale_factor_list.at(0)); + } else if (input.dim() == 3 && std::get_if(&mode)) { + TORCH_CHECK(false, "Got 3D input, but bilinear mode needs 4D input"); + } else if (input.dim() == 3 && std::get_if(&mode)) { + TORCH_CHECK(false, "Got 3D input, but trilinear mode needs 5D input"); + } else if (input.dim() == 4 && std::get_if(&mode)) { + TORCH_CHECK(false, "Got 4D input, but linear mode needs 3D input"); + } else if (input.dim() == 4 && std::get_if(&mode)) { + TORCH_CHECK( + align_corners != std::nullopt, "align_corners should be specified."); + if (antialias) { + return torch::_upsample_bilinear2d_aa( + input, + _interp_output_size(2, std::move(closed_over_args)), + *align_corners, + scale_factor_list.at(0), + scale_factor_list.at(1)); + } + return torch::upsample_bilinear2d( + input, + _interp_output_size(2, std::move(closed_over_args)), + *align_corners, + scale_factor_list.at(0), + scale_factor_list.at(1)); + } else if (input.dim() == 4 && std::get_if(&mode)) { + TORCH_CHECK(false, "Got 4D input, but trilinear mode needs 5D input"); + } else if (input.dim() == 5 && std::get_if(&mode)) { + TORCH_CHECK(false, "Got 5D input, but linear mode needs 3D input"); + } else if (input.dim() == 5 && std::get_if(&mode)) { + TORCH_CHECK(false, "Got 5D input, but bilinear mode needs 4D input"); + } else if (input.dim() == 5 && std::get_if(&mode)) { + TORCH_CHECK( + align_corners != std::nullopt, "align_corners should be specified."); + return torch::upsample_trilinear3d( + input, + _interp_output_size(3, std::move(closed_over_args)), + *align_corners, + scale_factor_list.at(0), + scale_factor_list.at(1), + scale_factor_list.at(2)); + } else if (input.dim() == 4 && std::get_if(&mode)) { + TORCH_CHECK( + align_corners != std::nullopt, "align_corners should be specified."); + if (antialias) { + return torch::_upsample_bicubic2d_aa( + input, + _interp_output_size(2, std::move(closed_over_args)), + *align_corners, + scale_factor_list.at(0), + scale_factor_list.at(1)); + } + return torch::upsample_bicubic2d( + input, + _interp_output_size(2, std::move(closed_over_args)), + *align_corners, + scale_factor_list.at(0), + scale_factor_list.at(1)); + } else { + TORCH_CHECK( + false, + "Input Error: Only 3D, 4D and 5D input Tensors supported " + "(got ", + input.dim(), + "D) for the modes: nearest | linear | bilinear | bicubic | trilinear " + "(got ", + enumtype::get_enum_name(mode), + ")"); + } +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.interpolate +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::InterpolateFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::interpolate(input, +/// F::InterpolateFuncOptions().size({4}).mode(torch::kNearest)); +/// ``` +inline Tensor interpolate( + const Tensor& input, + const InterpolateFuncOptions& options = {}) { + return detail::interpolate( + input, + options.size(), + options.scale_factor(), + options.mode(), + options.align_corners(), + options.recompute_scale_factor(), + options.antialias()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/vision.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/vision.h new file mode 100644 index 0000000000000000000000000000000000000000..78a015dcff856660a93f85c29990731b75336698 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/functional/vision.h @@ -0,0 +1,119 @@ +#pragma once + +#include +#include + +namespace torch::nn::functional { + +inline Tensor affine_grid( + const Tensor& theta, + const IntArrayRef& size, + bool align_corners = false) { + // enforce floating point dtype on theta + TORCH_CHECK( + theta.is_floating_point(), + "Expected theta to have floating point type, but got ", + theta.dtype()); + + // check that shapes and sizes match + if (size.size() == 4) { + TORCH_CHECK( + theta.dim() == 3 && theta.size(-2) == 2 && theta.size(-1) == 3, + "Expected a batch of 2D affine matrices of shape Nx2x3 for size ", + size, + ". Got ", + theta.sizes(), + "."); + } else if (size.size() == 5) { + TORCH_CHECK( + theta.dim() == 3 && theta.size(-2) == 3 && theta.size(-1) == 4, + "Expected a batch of 3D affine matrices of shape Nx3x4 for size ", + size, + ". Got ", + theta.sizes(), + "."); + } else { + TORCH_CHECK( + false, + "affine_grid only supports 4D and 5D sizes, ", + "for 2D and 3D affine transforms, respectively. ", + "Got size ", + size); + } + + if (*std::min_element(size.begin(), size.end()) <= 0) { + TORCH_CHECK(false, "Expected non-zero, positive output size. Got ", size); + } + + return torch::affine_grid_generator(theta, size, align_corners); +} + +// ============================================================================ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace detail { +inline Tensor grid_sample( + const Tensor& input, + const Tensor& grid, + GridSampleFuncOptions::mode_t mode, + GridSampleFuncOptions::padding_mode_t padding_mode, + std::optional align_corners) { + int64_t mode_enum = 0, padding_mode_enum = 0; + + if (std::holds_alternative(mode)) { + mode_enum = 0; + } else if (std::holds_alternative(mode)) { + mode_enum = 1; + } else { /// mode == 'bicubic' + mode_enum = 2; + } + + if (std::holds_alternative(padding_mode)) { + padding_mode_enum = 0; + } else if (std::holds_alternative(padding_mode)) { + padding_mode_enum = 1; + } else { /// padding_mode == 'reflection' + padding_mode_enum = 2; + } + + if (!align_corners.has_value()) { + TORCH_WARN( + "Default grid_sample and affine_grid behavior has changed ", + "to align_corners=False since 1.3.0. Please specify ", + "align_corners=True if the old behavior is desired. ", + "See the documentation of grid_sample for details."); + align_corners = false; + } + + return torch::grid_sampler( + input, grid, mode_enum, padding_mode_enum, align_corners.value()); +} +} // namespace detail +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +/// See +/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.grid_sample +/// about the exact behavior of this functional. +/// +/// See the documentation for `torch::nn::functional::GridSampleFuncOptions` +/// class to learn what optional arguments are supported for this functional. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::grid_sample(input, grid, +/// F::GridSampleFuncOptions().mode(torch::kBilinear).padding_mode(torch::kZeros).align_corners(true)); +/// ``` +inline Tensor grid_sample( + const Tensor& input, + const Tensor& grid, + const GridSampleFuncOptions& options = {}) { + return detail::grid_sample( + input, + grid, + options.mode(), + options.padding_mode(), + options.align_corners()); +} + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/init.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/init.h new file mode 100644 index 0000000000000000000000000000000000000000..d7a5476653c70dc19671bc7b9a4507b1a75f04e0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/init.h @@ -0,0 +1,122 @@ +#pragma once + +#include +#include +#include + +namespace torch { + +namespace nn::init { + +using NonlinearityType = std::variant< + enumtype::kLinear, + enumtype::kConv1D, + enumtype::kConv2D, + enumtype::kConv3D, + enumtype::kConvTranspose1D, + enumtype::kConvTranspose2D, + enumtype::kConvTranspose3D, + enumtype::kSigmoid, + enumtype::kTanh, + enumtype::kReLU, + enumtype::kLeakyReLU>; + +using FanModeType = std::variant; + +} // namespace nn::init + +namespace nn::init { + +/// Return the recommended gain value for the given nonlinearity function. +TORCH_API double calculate_gain( + NonlinearityType nonlinearity, + double param = 0.01); + +/// Fills the given `tensor` with the provided `value` in-place, and returns it. +/// No gradient will be recorded for this operation. +TORCH_API Tensor constant_(Tensor tensor, Scalar value); + +/// Fills the given `tensor` with the Dirac delta function in-place, and returns +/// it. No gradient will be recorded for this operation. +TORCH_API Tensor dirac_(Tensor tensor); + +/// Fills the given 2-dimensional `matrix` with an identity matrix. +/// No gradient will be recorded for this operation. +TORCH_API Tensor eye_(Tensor matrix); + +/// Fills the given 2-dimensional `matrix` with values drawn from a normal +/// distribution parameterized by `mean` and `std`. +/// No gradient will be recorded for this operation. +TORCH_API Tensor normal_(Tensor tensor, double mean = 0, double std = 1); + +/// Fills the given `tensor` with ones. +/// No gradient will be recorded for this operation. +TORCH_API Tensor ones_(Tensor tensor); + +/// Fills the input `Tensor` with a (semi) orthogonal matrix, as described in +/// "Exact solutions to the nonlinear dynamics of learning in deep linear neural +/// networks" - Saxe, A. et al. (2013). The input tensor must have at least 2 +/// dimensions, and for tensors with more than 2 dimensions the trailing +/// dimensions are flattened. +/// No gradient will be recorded for this operation. +TORCH_API Tensor orthogonal_(Tensor tensor, double gain = 1.0); + +/// Fills the 2D input `Tensor` as a sparse matrix, where the +/// non-zero elements will be drawn from a centered normal distribution +/// with the given standard deviation `std`, as described in "Deep learning via +/// Hessian-free optimization" - Martens, J. (2010). The `sparsity` is a real +/// value between 0 and 1 that controls the fraction of elements in each column +/// to be set to zero. +/// No gradient will be recorded for this operation. +TORCH_API Tensor sparse_(Tensor tensor, double sparsity, double std = 0.01); + +/// Fills the given 2-dimensional `matrix` with values drawn from a uniform +/// distribution parameterized by `low` and `high`. +/// No gradient will be recorded for this operation. +TORCH_API Tensor uniform_(Tensor tensor, double low = 0, double high = 1); + +/// Fills the input `Tensor` with values according to the method +/// described in "Delving deep into rectifiers: Surpassing human-level +/// performance on ImageNet classification" - He, K. et al. (2015), using a +/// normal distribution. Also known as He initialization. +/// No gradient will be recorded for this operation. +TORCH_API Tensor kaiming_normal_( + Tensor tensor, + double a = 0, + FanModeType mode = torch::kFanIn, + NonlinearityType nonlinearity = torch::kLeakyReLU); + +/// Fills the input `Tensor` with values according to the method +/// described in "Delving deep into rectifiers: Surpassing human-level +/// performance on ImageNet classification" - He, K. et al. (2015), using a +/// uniform distribution. Also known as He initialization. +/// No gradient will be recorded for this operation. +TORCH_API Tensor kaiming_uniform_( + Tensor tensor, + double a = 0, + FanModeType mode = torch::kFanIn, + NonlinearityType nonlinearity = torch::kLeakyReLU); + +/// Fills the input `Tensor` with values according to the method +/// described in "Understanding the difficulty of training deep feedforward +/// neural networks" - Glorot, X. & Bengio, Y. (2010). Values are scaled by the +/// `gain` parameter. No gradient will be recorded for this operation. +TORCH_API Tensor xavier_normal_(Tensor tensor, double gain = 1.0); + +/// Fills the input `Tensor` with values according to the method +/// described in "Understanding the difficulty of training deep feedforward +/// neural networks" - Glorot, X. & Bengio, Y. (2010), using a uniform +/// distribution. Values are scaled by the `gain` parameter +/// No gradient will be recorded for this operation. +TORCH_API Tensor xavier_uniform_(Tensor tensor, double gain = 1.0); + +/// Fills the given `tensor` with zeros. +/// No gradient will be recorded for this operation. +TORCH_API Tensor zeros_(Tensor tensor); + +TORCH_API std::tuple _calculate_fan_in_and_fan_out( + const Tensor& tensor); + +} // namespace nn::init + +} // namespace torch diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/module.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/module.h new file mode 100644 index 0000000000000000000000000000000000000000..4eff40199ff43c13a966e7367bab94b8c5955a0c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/module.h @@ -0,0 +1,700 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace torch::nn { + +/// The base class for all modules in PyTorch. +/// +/// \rst +/// .. note:: +/// The design and implementation of this class is largely based on the Python +/// API. You may want to consult the python documentation for +/// :py:class:`pytorch:torch.nn.Module` for further clarification on certain +/// methods or behavior. +/// \endrst +/// +/// A `Module` is an abstraction over the implementation of some function or +/// algorithm, possibly associated with some persistent data. A `Module` may +/// contain further `Module`s ("submodules"), each with their own +/// implementation, persistent data and further submodules. `Module`s can thus +/// be said to form a recursive tree structure. A `Module` is registered as a +/// submodule to another `Module` by calling `register_module()`, typically from +/// within a parent module's constructor. +/// +/// A distinction is made between three kinds of persistent data that may be +/// associated with a `Module`: +/// +/// 1. *Parameters*: tensors that record gradients, typically weights updated +/// during the backward step (e.g. the `weight` of a `Linear` module), +/// 2. *Buffers*: tensors that do not record gradients, typically updated during +/// the forward step, such as running statistics (e.g. `mean` and `variance` +/// in the `BatchNorm` module), +/// 3. Any additional state, not necessarily tensors, required for the +/// implementation or configuration of a `Module`. +/// +/// The first two kinds of state are special in that they may be registered +/// with the `Module` system to allow convenient access and batch configuration. +/// For example, registered parameters in any `Module` may be iterated over via +/// the `parameters()` accessor. Further, changing the data type of a `Module`'s +/// registered parameters can be done conveniently via `Module::to()`, e.g. +/// `module->to(torch::kCUDA)` to move all parameters to GPU memory. Lastly, +/// registered parameters and buffers are handled specially during a `clone()` +/// operation, which performs a deepcopy of a cloneable `Module` hierarchy. +/// +/// Parameters are registered with a `Module` via `register_parameter`. Buffers +/// are registered separately via `register_buffer`. These methods are part of +/// the public API of `Module` and are typically invoked from within a +/// concrete `Module`s constructor. +class TORCH_API Module : public std::enable_shared_from_this { + public: + using ModuleApplyFunction = std::function; + using ConstModuleApplyFunction = std::function; + using NamedModuleApplyFunction = + std::function; + using ConstNamedModuleApplyFunction = + std::function; + using ModulePointerApplyFunction = + std::function&)>; + using NamedModulePointerApplyFunction = + std::function&)>; + + /// Tells the base `Module` about the name of the submodule. + explicit Module(std::string name); + + /// Constructs the module without immediate knowledge of the submodule's name. + /// The name of the submodule is inferred via RTTI (if possible) the first + /// time `.name()` is invoked. + Module(); + Module(const Module&) = default; + Module& operator=(const Module&) = default; + Module(Module&&) noexcept = default; + Module& operator=(Module&&) noexcept = default; + + virtual ~Module() = default; + + /// Returns the name of the `Module`. + /// + /// A `Module` has an associated `name`, which is a string representation of + /// the kind of concrete `Module` it represents, such as `"Linear"` for the + /// `Linear` module. Under most circumstances, this name is automatically + /// inferred via runtime type information (RTTI). In the unusual circumstance + /// that you have this feature disabled, you may want to manually name your + /// `Module`s by passing the string name to the `Module` base class' + /// constructor. + const std::string& name() const noexcept; + + /// Performs a recursive deep copy of the module and all its registered + /// parameters, buffers and submodules. + /// + /// Optionally, this method sets the current device + /// to the one supplied before cloning. If no device is given, each + /// parameter and buffer will be moved to the device of its source. + /// + /// \rst + /// .. attention:: + /// Attempting to call the `clone()` method inherited from the base `Module` + /// class (the one documented here) will fail. To inherit an actual + /// implementation of `clone()`, you must subclass `Cloneable`. `Cloneable` + /// is templatized on the concrete module type, and can thus properly copy a + /// `Module`. This method is provided on the base class' API solely for an + /// easier-to-use polymorphic interface. + /// \endrst + virtual std::shared_ptr clone( + const std::optional& device = std::nullopt) const; + + /// Applies the `function` to the `Module` and recursively to every submodule. + /// The function must accept a `Module&`. + /// + /// \rst + /// .. code-block:: cpp + /// MyModule module; + /// module->apply([](nn::Module& module) { + /// std::cout << module.name() << std::endl; + /// }); + /// \endrst + void apply(const ModuleApplyFunction& function); + + /// Applies the `function` to the `Module` and recursively to every submodule. + /// The function must accept a `const Module&`. + /// + /// \rst + /// .. code-block:: cpp + /// MyModule module; + /// module->apply([](const nn::Module& module) { + /// std::cout << module.name() << std::endl; + /// }); + /// \endrst + void apply(const ConstModuleApplyFunction& function) const; + + /// Applies the `function` to the `Module` and recursively to every submodule. + /// The function must accept a `const std::string&` for the key of the module, + /// and a `Module&`. The key of the module itself is the empty string. If + /// `name_prefix` is given, it is prepended to every key as + /// `.` (and just `name_prefix` for the module itself). + /// + /// \rst + /// .. code-block:: cpp + /// MyModule module; + /// module->apply([](const std::string& key, nn::Module& module) { + /// std::cout << key << ": " << module.name() << std::endl; + /// }); + /// \endrst + void apply( + const NamedModuleApplyFunction& function, + const std::string& name_prefix = std::string()); + + /// Applies the `function` to the `Module` and recursively to every submodule. + /// The function must accept a `const std::string&` for the key of the module, + /// and a `const Module&`. The key of the module itself is the empty string. + /// If `name_prefix` is given, it is prepended to every key as + /// `.` (and just `name_prefix` for the module itself). + /// + /// \rst + /// .. code-block:: cpp + /// MyModule module; + /// module->apply([](const std::string& key, const nn::Module& module) { + /// std::cout << key << ": " << module.name() << std::endl; + /// }); + /// \endrst + void apply( + const ConstNamedModuleApplyFunction& function, + const std::string& name_prefix = std::string()) const; + + /// Applies the `function` to the `Module` and recursively to every submodule. + /// The function must accept a `const std::shared_ptr&`. + /// + /// \rst + /// .. code-block:: cpp + /// MyModule module; + /// module->apply([](const std::shared_ptr& module) { + /// std::cout << module->name() << std::endl; + /// }); + /// \endrst + void apply(const ModulePointerApplyFunction& function) const; + + /// Applies the `function` to the `Module` and recursively to every submodule. + /// The function must accept a `const std::string&` for the key of the module, + /// and a `const std::shared_ptr&`. The key of the module itself is + /// the empty string. If `name_prefix` is given, it is prepended to every key + /// as + /// `.` (and just `name_prefix` for the module itself). + /// + /// \rst + /// .. code-block:: cpp + /// MyModule module; + /// module->apply([](const std::string& key, + /// const std::shared_ptr& module) { + /// std::cout << key << ": " << module->name() << std::endl; + /// }); + /// \endrst + void apply( + const NamedModulePointerApplyFunction& function, + const std::string& name_prefix = std::string()) const; + + /// Returns the parameters of this `Module` and if `recurse` is true, also + /// recursively of every submodule. + std::vector parameters(bool recurse = true) const; + + /// Returns an `OrderedDict` with the parameters of this `Module` along with + /// their keys, and if `recurse` is true also recursively of every submodule. + OrderedDict named_parameters(bool recurse = true) const; + + /// Returns the buffers of this `Module` and if `recurse` is true, also + /// recursively of every submodule. + std::vector buffers(bool recurse = true) const; + + /// Returns an `OrderedDict` with the buffers of this `Module` along with + /// their keys, and if `recurse` is true also recursively of every submodule. + OrderedDict named_buffers(bool recurse = true) const; + + /// Returns the submodules of this `Module` (the entire submodule hierarchy) + /// and if `include_self` is true, also inserts a `shared_ptr` to this module + /// in the first position. + /// + /// \rst + /// .. warning:: + /// Only pass `include_self` as `true` if this `Module` is stored in a + /// `shared_ptr`! Otherwise an exception will be thrown. You may still call + /// this method with `include_self` set to false if your `Module` is not + /// stored in a `shared_ptr`. + /// \endrst + std::vector> modules(bool include_self = true) const; + + /// Returns an `OrderedDict` of the submodules of this `Module` (the entire + /// submodule hierarchy) and their keys, and if `include_self` is true, also + /// inserts a `shared_ptr` to this module in the first position. If + /// `name_prefix` is given, it is prepended to every key as + /// `.` (and just `name_prefix` for the module itself). + /// + /// \rst + /// .. warning:: + /// Only pass `include_self` as `true` if this `Module` is stored in a + /// `shared_ptr`! Otherwise an exception will be thrown. You may still call + /// this method with `include_self` set to false if your `Module` is not + /// stored in a `shared_ptr`. + /// \endrst + OrderedDict> named_modules( + const std::string& name_prefix = std::string(), + bool include_self = true) const; + + /// Returns the direct submodules of this `Module`. + std::vector> children() const; + + /// Returns an `OrderedDict` of the direct submodules of this `Module` and + /// their keys. + OrderedDict> named_children() const; + + /// Enables "training" mode. + virtual void train(bool on = true); + + /// Calls train(false) to enable "eval" mode. + /// Do not override this method, override `train()` instead. + void eval(); + + /// True if the module is in training mode. + /// + /// Every `Module` has a boolean associated with it that determines whether + /// the `Module` is currently in *training* mode (set via `.train()`) or in + /// *evaluation* (inference) mode (set via `.eval()`). This property is + /// exposed via `is_training()`, and may be used by the implementation of a + /// concrete module to modify its runtime behavior. See the `BatchNorm` or + /// `Dropout` modules for examples of `Module`s that use different code paths + /// depending on this property. + virtual bool is_training() const noexcept; + + /// Recursively casts all parameters to the given `dtype` and `device`. + /// + /// If `non_blocking` is true and the source is in pinned memory and + /// destination is on the GPU or vice versa, the copy is performed + /// asynchronously with respect to the host. Otherwise, the argument has no + /// effect. + virtual void to( + torch::Device device, + torch::Dtype dtype, + bool non_blocking = false); + + /// Recursively casts all parameters to the given dtype. + /// + /// If `non_blocking` is true and the source is in pinned memory and + /// destination is on the GPU or vice versa, the copy is performed + /// asynchronously with respect to the host. Otherwise, the argument has no + /// effect. + virtual void to(torch::Dtype dtype, bool non_blocking = false); + + /// Recursively moves all parameters to the given device. + /// + /// If `non_blocking` is true and the source is in pinned memory and + /// destination is on the GPU or vice versa, the copy is performed + /// asynchronously with respect to the host. Otherwise, the argument has no + /// effect. + virtual void to(torch::Device device, bool non_blocking = false); + + /// Recursively zeros out the `grad` value of each registered parameter. + virtual void zero_grad(bool set_to_none = true); + + /// Attempts to cast this `Module` to the given `ModuleType`. + /// + /// This method is useful when calling `apply()`. + /// \rst + /// .. code-block:: cpp + /// + /// void initialize_weights(nn::Module& module) { + /// torch::NoGradGuard no_grad; + /// if (auto* linear = module.as()) { + /// linear->weight.normal_(0.0, 0.02); + /// } + /// } + /// + /// MyModule module; + /// module->apply(initialize_weights); + /// \endrst + template + typename ModuleType::ContainedType* as() noexcept; + + /// Attempts to cast this `Module` to the given `ModuleType`. + /// + /// This method is useful when calling `apply()`. + /// \rst + /// .. code-block:: cpp + /// void initialize_weights(nn::Module& module) { + /// torch::NoGradGuard no_grad; + /// if (auto* linear = module.as()) { + /// linear->weight.normal_(0.0, 0.02); + /// } + /// } + /// + /// MyModule module; + /// module->apply(initialize_weights); + /// \endrst + template + const typename ModuleType::ContainedType* as() const noexcept; + + /// Attempts to cast this `Module` to the given `ModuleType`. + /// + /// This method is useful when calling `apply()`. + /// \rst + /// .. code-block:: cpp + /// + /// void initialize_weights(nn::Module& module) { + /// torch::NoGradGuard no_grad; + /// if (auto* linear = module.as()) { + /// linear->weight.normal_(0.0, 0.02); + /// } + /// } + /// + /// MyModule module; + /// module.apply(initialize_weights); + /// \endrst + template < + typename ModuleType, + typename = torch::detail::disable_if_module_holder_t> + ModuleType* as() noexcept; + + /// Attempts to cast this `Module` to the given `ModuleType`. + /// + /// This method is useful when calling `apply()`. + /// \rst + /// .. code-block:: cpp + /// + /// void initialize_weights(nn::Module& module) { + /// torch::NoGradGuard no_grad; + /// if (auto* linear = module.as()) { + /// linear->weight.normal_(0.0, 0.02); + /// } + /// } + /// + /// MyModule module; + /// module.apply(initialize_weights); + /// \endrst + template < + typename ModuleType, + typename = torch::detail::disable_if_module_holder_t> + const ModuleType* as() const noexcept; + + /// Serializes the `Module` into the given `OutputArchive`. + /// + /// If the `Module` contains unserializable submodules (e.g. + /// `nn::Functional`), those submodules are skipped when serializing. + virtual void save(serialize::OutputArchive& archive) const; + + /// Deserializes the `Module` from the given `InputArchive`. + /// + /// If the `Module` contains unserializable submodules (e.g. + /// `nn::Functional`), we don't check the existence of those submodules in the + /// `InputArchive` when deserializing. + virtual void load(serialize::InputArchive& archive); + + /// Streams a pretty representation of the `Module` into the given `stream`. + /// By default, this representation will be the name of the module (taken from + /// `name()`), followed by a recursive pretty print of all of the `Module`'s + /// submodules. + /// + /// Override this method to change the pretty print. The input + /// `stream` should be returned from the method, to allow easy chaining. + virtual void pretty_print(std::ostream& stream) const; + + /// Returns whether the `Module` is serializable. + virtual bool is_serializable() const; + + /// Registers a parameter with this `Module`. + /// + /// A parameter should be any gradient-recording tensor used in the + /// implementation of your `Module`. Registering it makes it available to + /// methods such as `parameters()`, `clone()` or `to().` + /// + /// Note that registering an undefined Tensor (e.g. + /// `module.register_parameter("param", Tensor())`) is allowed, and is + /// equivalent to `module.register_parameter("param", None)` in Python API. + /// + /// \rst + /// .. code-block:: cpp + /// + /// MyModule::MyModule() { + /// weight_ = register_parameter("weight", torch::randn({A, B})); + /// } + /// \endrst + Tensor& register_parameter( + std::string name, + Tensor tensor, + bool requires_grad = true); + + /// Registers a buffer with this `Module`. + /// + /// A buffer is intended to be state in your module that does not record + /// gradients, such as running statistics. Registering it makes it available + /// to methods such as `buffers()`, `clone()` or `to(). + /// + /// \rst + /// .. code-block:: cpp + /// + /// MyModule::MyModule() { + /// mean_ = register_buffer("mean", torch::empty({num_features_})); + /// } + /// \endrst + Tensor& register_buffer(std::string name, Tensor tensor); + + /// Registers a submodule with this `Module`. + /// + /// Registering a module makes it available to methods such as `modules()`, + /// `clone()` or `to()`. + /// + /// \rst + /// .. code-block:: cpp + /// + /// MyModule::MyModule() { + /// submodule_ = register_module("linear", torch::nn::Linear(3, 4)); + /// } + /// \endrst + template + std::shared_ptr register_module( + std::string name, + std::shared_ptr module); + + /// Registers a submodule with this `Module`. + /// + /// This method deals with `ModuleHolder`s. + /// + /// Registering a module makes it available to methods such as `modules()`, + /// `clone()` or `to()`. + /// + /// \rst + /// .. code-block:: cpp + /// + /// MyModule::MyModule() { + /// submodule_ = register_module("linear", torch::nn::Linear(3, 4)); + /// } + /// \endrst + template + std::shared_ptr register_module( + std::string name, + ModuleHolder module_holder); + + /// Replaces a registered submodule with this `Module`. + /// + /// This takes care of the registration, if you used submodule members, you + /// should + // assign the submodule as well, i.e. use as + /// module->submodule_ = module->replace_module("linear", + /// torch::nn::Linear(3, 4)); + /// It only works when a module of the name is already registered. + /// + /// This is useful for replacing a module after initialization, e.g. + /// for finetuning. + template + std::shared_ptr replace_module( + const std::string& name, + std::shared_ptr module); + + /// Replaces a registered submodule with this `Module`. + /// This method deals with `ModuleHolder`s. + /// + /// This takes care of the registration, if you used submodule members, you + /// should + // assign the submodule as well, i.e. use as + /// module->submodule_ = module->replace_module("linear", linear_holder); + /// It only works when a module of the name is already registered. + /// + /// This is useful for replacing a module after initialization, e.g. + /// for finetuning. + template + std::shared_ptr replace_module( + const std::string& name, + ModuleHolder module_holder); + + /// Unregisters a submodule from this `Module`. If there is no such module + /// with `name` an exception is thrown. + void unregister_module(const std::string& name); + + protected: + /// The following three functions allow a module with default arguments in its + /// forward method to be used in a Sequential module. + /// You should NEVER override these functions manually. Instead, you should + /// use the `FORWARD_HAS_DEFAULT_ARGS` macro. + virtual bool _forward_has_default_args() { + return false; + } + + virtual unsigned int _forward_num_required_args() { + TORCH_CHECK( + false, + "torch::nn::Module subclass that has default arguments in `forward` method ", + "must override `_forward_num_required_args` method. Please use ", + "`FORWARD_HAS_DEFAULT_ARGS` macro to do so."); + } + + virtual std::vector _forward_populate_default_args( + std::vector&& arguments) { + TORCH_CHECK( + false, + "torch::nn::Module subclass that has default arguments in `forward` method ", + "must override `_forward_populate_default_args` method. Please use ", + "`FORWARD_HAS_DEFAULT_ARGS` macro to do so."); + } + + /// The registered parameters of this `Module`. + /// Inorder to access parameters_ in ParameterDict and ParameterList + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + OrderedDict parameters_; + + private: + // Friend classes. + + template + friend class Cloneable; + + template + friend struct AnyModuleHolder; + + /// Pretty prints the given `Module` into the `ostream`. + TORCH_API friend std::ostream& operator<<( + std::ostream& stream, + const nn::Module& module); + + // data parallel using this method to configure gradient edges during the + // replicate step. + template + friend void replicate_grad_edges( + const std::shared_ptr& module, + const std::vector>& replicas, + const std::vector& devices); + + // Private methods. + + /// Used in the implementation of `Cloneable`. + virtual void clone_(Module& other, const std::optional& device); + + /// The implementation of the various `to()` methods. + template + void to_impl(Ts&&... ts); + + /// Implements pretty printing the module hierarchy. + void pretty_print_recursive( + std::ostream& stream, + const std::string& indentation) const; + + /// Applies the `function` to every submodule recursively, starting at this + /// `Module`'s children (thus not including the module itself). + void apply_to_submodules( + const NamedModulePointerApplyFunction& function, + const std::string& name_prefix = std::string()) const; + + /// Returns a shared_ptr to `this` in a safe (checked) way. + std::shared_ptr shared_from_this_checked() const; + + /// The registered buffers of this `Module`. + OrderedDict buffers_; + + /// The registered (direct) submodules of this `Module`. + OrderedDict> children_; + + /// The module's name (e.g. "LSTM"). + mutable std::optional name_; + + /// Whether the module is in training mode. + bool is_training_{true}; +}; + +/// Serialize a `Module` pointer into an `OutputArchive`. +TORCH_API serialize::OutputArchive& operator<<( + serialize::OutputArchive& archive, + const std::shared_ptr& module); + +/// Deserializes a `Module` from an `InputArchive`. +TORCH_API serialize::InputArchive& operator>>( + serialize::InputArchive& archive, + const std::shared_ptr& module); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +typename ModuleType::ContainedType* Module::as() noexcept { + // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for + // `Linear`, since `LinearImpl` inherits `nn::Module`. + return as(); +} + +template +const typename ModuleType::ContainedType* Module::as() const noexcept { + // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for + // `Linear`, since `LinearImpl` inherits `nn::Module`. + return as(); +} + +template +ModuleType* Module::as() noexcept { + return dynamic_cast(this); +} + +template +const ModuleType* Module::as() const noexcept { + return dynamic_cast(this); +} + +template +std::shared_ptr Module::register_module( + std::string name, + std::shared_ptr module) { + TORCH_CHECK(!name.empty(), "Submodule name must not be empty"); + TORCH_CHECK( + name.find('.') == std::string::npos, + "Submodule name must not contain a dot (got '", + name, + "')"); + auto& base_module = children_.insert(std::move(name), std::move(module)); + return std::dynamic_pointer_cast(base_module); +} + +template +std::shared_ptr Module::register_module( + std::string name, + ModuleHolder module_holder) { + return register_module(std::move(name), module_holder.ptr()); +} + +template +std::shared_ptr Module::replace_module( + const std::string& name, + std::shared_ptr module) { + auto& base_module = (children_[name] = std::move(module)); + return std::dynamic_pointer_cast(base_module); +} + +template +std::shared_ptr Module::replace_module( + const std::string& name, + ModuleHolder module_holder) { + return replace_module(name, module_holder.ptr()); +} + +template +void Module::to_impl(Ts&&... ts) { + // First call `to()` on every child module. + for (auto& child : children_) { + child.value()->to(ts...); + } + // Then move every parameter to the new dtype/device. + for (auto& parameter : named_parameters(/*recurse=*/false)) { + parameter->set_data(parameter->to(ts...)); + } + // Then move every buffer to the new dtype/device. + for (auto& buffer : named_buffers(/*recurse=*/false)) { + buffer->set_data(buffer->to(ts...)); + } +} + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules.h new file mode 100644 index 0000000000000000000000000000000000000000..e037d52a8535490ff5ecb17e578df5b4101ee9a3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules.h @@ -0,0 +1,36 @@ +#pragma once + +// Common +#include + +// Containers +#include +#include +#include +#include +#include +#include +#include +#include + +// Layers +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/_functions.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..f7cc8d0eb935480bd6b4374506310d1a0cbf0dc6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/_functions.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn::functions { + +class CrossMapLRN2d : public torch::autograd::Function { + public: + static torch::autograd::Variable forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const CrossMapLRN2dOptions& options); + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output); +}; + +} // namespace torch::nn::functions diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/activation.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/activation.h new file mode 100644 index 0000000000000000000000000000000000000000..806fbd2f0f876b57e3e0cd19b806a008ce37f90d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/activation.h @@ -0,0 +1,873 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace torch::nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies elu over a given input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ELU to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ELUOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ELU model(ELUOptions().alpha(42.42).inplace(true)); +/// ``` +class TORCH_API ELUImpl : public torch::nn::Cloneable { + public: + explicit ELUImpl(const ELUOptions& options_ = {}); + + Tensor forward(Tensor input); + + void reset() override; + + /// Pretty prints the `ELU` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + ELUOptions options; +}; + +/// A `ModuleHolder` subclass for `ELUImpl`. +/// See the documentation for `ELUImpl` class to learn what methods it +/// provides, and examples of how to use `ELU` with `torch::nn::ELUOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(ELU); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the selu function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.SELU to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::SELUOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// SELU model(SELUOptions().inplace(true)); +/// ``` +class TORCH_API SELUImpl : public torch::nn::Cloneable { + public: + explicit SELUImpl(const SELUOptions& options_ = {}); + + Tensor forward(Tensor input); + + void reset() override; + + /// Pretty prints the `SELU` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + SELUOptions options; +}; + +/// A `ModuleHolder` subclass for `SELUImpl`. +/// See the documentation for `SELUImpl` class to learn what methods it +/// provides, and examples of how to use `SELU` with `torch::nn::SELUOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(SELU); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the hard shrinkage function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Hardshrink to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::HardshrinkOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Hardshrink model(HardshrinkOptions().lambda(42.42)); +/// ``` +class TORCH_API HardshrinkImpl : public torch::nn::Cloneable { + public: + explicit HardshrinkImpl(const HardshrinkOptions& options_ = {}); + + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `Hardshrink` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + HardshrinkOptions options; +}; + +/// A `ModuleHolder` subclass for `HardshrinkImpl`. +/// See the documentation for `HardshrinkImpl` class to learn what methods it +/// provides, and examples of how to use `Hardshrink` with +/// `torch::nn::HardshrinkOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Hardshrink); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardtanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the HardTanh function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Hardtanh to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::HardtanhOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Hardtanh +/// model(HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true)); +/// ``` +class TORCH_API HardtanhImpl : public torch::nn::Cloneable { + public: + explicit HardtanhImpl(const HardtanhOptions& options_ = {}); + + Tensor forward(Tensor input); + + void reset() override; + + /// Pretty prints the `Hardtanh` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + HardtanhOptions options; +}; + +/// A `ModuleHolder` subclass for `HardtanhImpl`. +/// See the documentation for `HardtanhImpl` class to learn what methods it +/// provides, and examples of how to use `Hardtanh` with +/// `torch::nn::HardtanhOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Hardtanh); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LeakyReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the LeakyReLU function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.LeakyReLU to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::LeakyReLUOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// LeakyReLU model(LeakyReLUOptions().negative_slope(0.42).inplace(true)); +/// ``` +class TORCH_API LeakyReLUImpl : public torch::nn::Cloneable { + public: + explicit LeakyReLUImpl(const LeakyReLUOptions& options_ = {}); + + Tensor forward(Tensor input); + + void reset() override; + + /// Pretty prints the `LeakyReLU` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + LeakyReLUOptions options; +}; + +/// A `ModuleHolder` subclass for `LeakyReLUImpl`. +/// See the documentation for `LeakyReLUImpl` class to learn what methods it +/// provides, and examples of how to use `LeakyReLU` with +/// `torch::nn::LeakyReLUOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(LeakyReLU); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the LogSigmoid function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.LogSigmoid to learn +/// about the exact behavior of this module. +class TORCH_API LogSigmoidImpl : public torch::nn::Cloneable { + public: + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `LogSigmoid` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `LogSigmoidImpl`. +/// See the documentation for `LogSigmoidImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(LogSigmoid); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the Softmax function. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Softmax to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::SoftmaxOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Softmax model(SoftmaxOptions(1)); +/// ``` +class TORCH_API SoftmaxImpl : public torch::nn::Cloneable { + public: + explicit SoftmaxImpl(int64_t dim) : SoftmaxImpl(SoftmaxOptions(dim)) {} + explicit SoftmaxImpl(const SoftmaxOptions& options_); + + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `Softmax` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + SoftmaxOptions options; +}; + +/// A `ModuleHolder` subclass for `SoftmaxImpl`. +/// See the documentation for `SoftmaxImpl` class to learn what methods it +/// provides, and examples of how to use `Softmax` with +/// `torch::nn::SoftmaxOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Softmax); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmin ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the Softmin function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Softmin to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::SoftminOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Softmin model(SoftminOptions(1)); +/// ``` +class TORCH_API SoftminImpl : public torch::nn::Cloneable { + public: + explicit SoftminImpl(int64_t dim) : SoftminImpl(SoftminOptions(dim)) {} + explicit SoftminImpl(const SoftminOptions& options_); + + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `Softmin` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + SoftminOptions options; +}; + +/// A `ModuleHolder` subclass for `SoftminImpl`. +/// See the documentation for `SoftminImpl` class to learn what methods it +/// provides, and examples of how to use `Softmin` with +/// `torch::nn::SoftminOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Softmin); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSoftmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the LogSoftmax function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.LogSoftmax to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::LogSoftmaxOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// LogSoftmax model(LogSoftmaxOptions(1)); +/// ``` +class TORCH_API LogSoftmaxImpl : public torch::nn::Cloneable { + public: + explicit LogSoftmaxImpl(int64_t dim) + : LogSoftmaxImpl(LogSoftmaxOptions(dim)) {} + explicit LogSoftmaxImpl(const LogSoftmaxOptions& options_); + + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `LogSoftmax` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + LogSoftmaxOptions options; +}; + +/// A `ModuleHolder` subclass for `LogSoftmaxImpl`. +/// See the documentation for `LogSoftmaxImpl` class to learn what methods it +/// provides, and examples of how to use `LogSoftmax` with +/// `torch::nn::LogSoftmaxOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(LogSoftmax); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the Softmax2d function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Softmax2d to learn +/// about the exact behavior of this module. +class TORCH_API Softmax2dImpl : public torch::nn::Cloneable { + public: + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `Softmax2d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `Softmax2dImpl`. +/// See the documentation for `Softmax2dImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(Softmax2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the PReLU function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.PReLU to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::PReLUOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// PReLU model(PReLUOptions().num_parameters(42)); +/// ``` +class TORCH_API PReLUImpl : public torch::nn::Cloneable { + public: + explicit PReLUImpl(const PReLUOptions& options_ = {}); + + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `PReLU` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + PReLUOptions options; + + /// The learned weight. + Tensor weight; +}; + +/// A `ModuleHolder` subclass for `PReLUImpl`. +/// See the documentation for `PReLUImpl` class to learn what methods it +/// provides, and examples of how to use `PReLU` with `torch::nn::PReLUOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(PReLU); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the ReLU function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ReLU to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ReLUOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ReLU model(ReLUOptions().inplace(true)); +/// ``` +class TORCH_API ReLUImpl : public torch::nn::Cloneable { + public: + explicit ReLUImpl(const ReLUOptions& options_ = {}); + + Tensor forward(Tensor input); + + void reset() override; + + /// Pretty prints the `ReLU` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + ReLUOptions options; +}; + +/// A `ModuleHolder` subclass for `ReLUImpl`. +/// See the documentation for `ReLUImpl` class to learn what methods it +/// provides, and examples of how to use `ReLU` with `torch::nn::ReLUOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(ReLU); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReLU6 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the ReLU6 function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ReLU6 to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ReLU6Options` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ReLU6 model(ReLU6Options().inplace(true)); +/// ``` +class TORCH_API ReLU6Impl : public torch::nn::Cloneable { + public: + explicit ReLU6Impl(const ReLU6Options& options_ = {}); + + Tensor forward(Tensor input); + + void reset() override; + + /// Pretty prints the `ReLU6` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + ReLU6Options options; +}; + +/// A `ModuleHolder` subclass for `ReLU6Impl`. +/// See the documentation for `ReLU6Impl` class to learn what methods it +/// provides, and examples of how to use `ReLU6` with `torch::nn::ReLU6Options`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(ReLU6); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the RReLU function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.RReLU to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::RReLUOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// RReLU model(RReLUOptions().lower(0.24).upper(0.42).inplace(true)); +/// ``` +class TORCH_API RReLUImpl : public torch::nn::Cloneable { + public: + explicit RReLUImpl(const RReLUOptions& options_ = {}); + + Tensor forward(Tensor input); + + void reset() override; + + /// Pretty prints the `RReLU` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + RReLUOptions options; +}; + +/// A `ModuleHolder` subclass for `RReLUImpl`. +/// See the documentation for `RReLUImpl` class to learn what methods it +/// provides, and examples of how to use `RReLU` with `torch::nn::RReLUOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(RReLU); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies celu over a given input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.CELU to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::CELUOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// CELU model(CELUOptions().alpha(42.42).inplace(true)); +/// ``` +class TORCH_API CELUImpl : public torch::nn::Cloneable { + public: + explicit CELUImpl(const CELUOptions& options_ = {}); + + Tensor forward(Tensor input); + + void reset() override; + + /// Pretty prints the `CELU` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + CELUOptions options; +}; + +/// A `ModuleHolder` subclass for `CELUImpl`. +/// See the documentation for `CELUImpl` class to learn what methods it +/// provides, and examples of how to use `CELU` with `torch::nn::CELUOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(CELU); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies glu over a given input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.GLU to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::GLUOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// GLU model(GLUOptions(1)); +/// ``` +class TORCH_API GLUImpl : public torch::nn::Cloneable { + public: + explicit GLUImpl(const GLUOptions& options_ = {}); + + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `GLU` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + GLUOptions options; +}; + +/// A `ModuleHolder` subclass for `GLUImpl`. +/// See the documentation for `GLUImpl` class to learn what methods it +/// provides, and examples of how to use `GLU` with `torch::nn::GLUOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(GLU); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies gelu over a given input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.GELU to learn +/// about the exact behavior of this module. +class TORCH_API GELUImpl : public torch::nn::Cloneable { + public: + explicit GELUImpl(GELUOptions options_ = {}); + + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `GELU` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + GELUOptions options; +}; + +/// A `ModuleHolder` subclass for `GELUImpl`. +/// See the documentation for `GELUImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(GELU); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SiLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies silu over a given input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.SiLU to learn +/// about the exact behavior of this module. +class TORCH_API SiLUImpl : public torch::nn::Cloneable { + public: + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `SiLU` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `SiLUImpl`. +/// See the documentation for `SiLUImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(SiLU); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Mish ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies mish over a given input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Mish to learn +/// about the exact behavior of this module. +class TORCH_API MishImpl : public torch::nn::Cloneable { + public: + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `Mish` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `MishImpl`. +/// See the documentation for `MishImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(Mish); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies sigmoid over a given input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Sigmoid to learn +/// about the exact behavior of this module. +class TORCH_API SigmoidImpl : public torch::nn::Cloneable { + public: + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `Sigmoid` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `SigmoidImpl`. +/// See the documentation for `SigmoidImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(Sigmoid); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softplus ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies softplus over a given input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Softplus to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::SoftplusOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Softplus model(SoftplusOptions().beta(0.24).threshold(42.42)); +/// ``` +class TORCH_API SoftplusImpl : public torch::nn::Cloneable { + public: + explicit SoftplusImpl(const SoftplusOptions& options_ = {}); + + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `Softplus` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + SoftplusOptions options; +}; + +/// A `ModuleHolder` subclass for `SoftplusImpl`. +/// See the documentation for `SoftplusImpl` class to learn what methods it +/// provides, and examples of how to use `Softplus` with +/// `torch::nn::SoftplusOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Softplus); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the soft shrinkage function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Softshrink to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::SoftshrinkOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Softshrink model(SoftshrinkOptions(42.42)); +/// ``` +class TORCH_API SoftshrinkImpl : public torch::nn::Cloneable { + public: + explicit SoftshrinkImpl(const SoftshrinkOptions& options_ = {}); + + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `Softshrink` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + SoftshrinkOptions options; +}; + +/// A `ModuleHolder` subclass for `SoftshrinkImpl`. +/// See the documentation for `SoftshrinkImpl` class to learn what methods it +/// provides, and examples of how to use `Softshrink` with +/// `torch::nn::SoftshrinkOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Softshrink); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softsign ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies Softsign over a given input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Softsign to learn +/// about the exact behavior of this module. +class TORCH_API SoftsignImpl : public torch::nn::Cloneable { + public: + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `Softsign` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `SoftsignImpl`. +/// See the documentation for `SoftsignImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(Softsign); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies Tanh over a given input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Tanh to learn +/// about the exact behavior of this module. +class TORCH_API TanhImpl : public torch::nn::Cloneable { + public: + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `Tanh` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `TanhImpl`. +/// See the documentation for `TanhImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(Tanh); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tanhshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies Tanhshrink over a given input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Tanhshrink to learn +/// about the exact behavior of this module. +class TORCH_API TanhshrinkImpl : public torch::nn::Cloneable { + public: + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `Tanhshrink` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `TanhshrinkImpl`. +/// See the documentation for `TanhshrinkImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(Tanhshrink); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Threshold ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the Threshold function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Threshold to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ThresholdOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Threshold model(ThresholdOptions(42.42, 24.24).inplace(true)); +/// ``` +class TORCH_API ThresholdImpl : public torch::nn::Cloneable { + public: + ThresholdImpl(double threshold, double value) + : ThresholdImpl(ThresholdOptions(threshold, value)) {} + explicit ThresholdImpl(const ThresholdOptions& options_); + + Tensor forward(Tensor input); + + void reset() override; + + /// Pretty prints the `Threshold` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + ThresholdOptions options; +}; + +/// A `ModuleHolder` subclass for `ThresholdImpl`. +/// See the documentation for `ThresholdImpl` class to learn what methods it +/// provides, and examples of how to use `Threshold` with +/// `torch::nn::ThresholdOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Threshold); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiheadAttention ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the MultiheadAttention function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MultiheadAttention +/// to learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MultiheadAttentionOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MultiheadAttention model(MultiheadAttentionOptions(20, 10).bias(false)); +/// ``` +class TORCH_API MultiheadAttentionImpl + : public torch::nn::Cloneable { + public: + MultiheadAttentionImpl(int64_t embed_dim, int64_t num_heads) + : MultiheadAttentionImpl( + MultiheadAttentionOptions(embed_dim, num_heads)) {} + explicit MultiheadAttentionImpl(const MultiheadAttentionOptions& options_); + + std::tuple forward( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const Tensor& key_padding_mask = {}, + bool need_weights = true, + const Tensor& attn_mask = {}, + bool average_attn_weights = true); + + protected: + FORWARD_HAS_DEFAULT_ARGS( + {3, AnyValue(Tensor())}, + {4, AnyValue(true)}, + {5, AnyValue(Tensor())}, + {6, AnyValue(true)}) + + public: + void reset() override; + + void _reset_parameters(); + + /// The options with which this `Module` was constructed. + MultiheadAttentionOptions options; + + bool _qkv_same_embed_dim{}; + Tensor in_proj_weight; + Tensor in_proj_bias; + Tensor bias_k; + Tensor bias_v; + Linear out_proj = nullptr; + Tensor q_proj_weight; + Tensor k_proj_weight; + Tensor v_proj_weight; + int64_t head_dim{}; +}; + +/// A `ModuleHolder` subclass for `MultiheadAttentionImpl`. +/// See the documentation for `MultiheadAttentionImpl` class to learn what +/// methods it provides, and examples of how to use `MultiheadAttention` with +/// `torch::nn::MultiheadAttentionOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(MultiheadAttention); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/adaptive.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/adaptive.h new file mode 100644 index 0000000000000000000000000000000000000000..7833b01297d2d126d43563f36d4e21fa1ff59470 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/adaptive.h @@ -0,0 +1,109 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace torch::nn { + +/// The output of a single invocation of an AdaptiveLogSoftmaxWithLoss +/// module's `forward()` method. +struct TORCH_API ASMoutput { + ASMoutput(Tensor output_, double loss_); + + /// Tensor containing computed target log probabilities for each example + Tensor output; + + /// Scalar representing the computed negative log likelihood loss + double loss; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveLogSoftmaxWithLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Efficient softmax approximation as described in +/// `Efficient softmax approximation for GPUs`_ by Edouard Grave, Armand Joulin, +/// Moustapha Cissé, David Grangier, and Hervé Jégou. +/// See +/// https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveLogSoftmaxWithLoss +/// to learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AdaptiveLogSoftmaxWithLossOptions` +/// class to learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10, +/// {4, 8}).div_value(2.).head_bias(true)); +/// ``` +class TORCH_API AdaptiveLogSoftmaxWithLossImpl + : public Cloneable { + public: + AdaptiveLogSoftmaxWithLossImpl( + int64_t in_features, + int64_t n_classes, + std::vector cutoffs) + : AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions( + in_features, + n_classes, + std::move(cutoffs))) {} + + explicit AdaptiveLogSoftmaxWithLossImpl( + AdaptiveLogSoftmaxWithLossOptions options_); + + ASMoutput forward(const Tensor& input, const Tensor& target); + + void reset() override; + + void reset_parameters(); + + /// Pretty prints the `AdaptiveLogSoftmaxWithLoss` module into the given + /// `stream`. + void pretty_print(std::ostream& stream) const override; + + /// Given input tensor, and output of `head`, computes the log of the full + /// distribution + Tensor _get_full_log_prob(const Tensor& input, const Tensor& head_output); + + /// Computes log probabilities for all n_classes + Tensor log_prob(const Tensor& input); + + /// This is equivalent to `log_pob(input).argmax(1)` but is more efficient in + /// some cases + Tensor predict(const Tensor& input); + + /// The options with which this `Module` was constructed + AdaptiveLogSoftmaxWithLossOptions options; + + /// Cutoffs used to assign targets to their buckets. It should be an ordered + /// Sequence of integers sorted in the increasing order + std::vector cutoffs; + + int64_t shortlist_size; + + /// Number of clusters + int64_t n_clusters; + + /// Output size of head classifier + int64_t head_size; + + Linear head = nullptr; + + ModuleList tail; +}; + +/// A `ModuleHolder` subclass for `AdaptiveLogSoftmaxWithLossImpl`. +/// See the documentation for `AdaptiveLogSoftmaxWithLossImpl` class to learn +/// what methods it provides, and examples of how to use +/// `AdaptiveLogSoftmaxWithLoss` with +/// `torch::nn::AdaptiveLogSoftmaxWithLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(AdaptiveLogSoftmaxWithLoss); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/batchnorm.h new file mode 100644 index 0000000000000000000000000000000000000000..8437ffd7afb8ecd447eba13e92d3bfdf4b465db5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/batchnorm.h @@ -0,0 +1,242 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace torch::nn { + +/// Base class for all (dimension-specialized) batchnorm and instancenorm +/// modules. +template +class NormImplBase : public torch::nn::Cloneable { + protected: + virtual void _check_input_dim(const Tensor& input) = 0; + + public: + NormImplBase(const DerivedOptions& options_) : options(options_) { + NormImplBase::reset(); + } + + void reset() override { + if (options.affine()) { + weight = this->register_parameter( + "weight", torch::empty({options.num_features()})); + bias = this->register_parameter( + "bias", torch::empty({options.num_features()})); + } else { + weight = + this->register_parameter("weight", Tensor(), /*requires_grad=*/false); + bias = + this->register_parameter("bias", Tensor(), /*requires_grad=*/false); + } + if (options.track_running_stats()) { + running_mean = this->register_buffer( + "running_mean", torch::zeros({options.num_features()})); + running_var = this->register_buffer( + "running_var", torch::ones({options.num_features()})); + num_batches_tracked = this->register_buffer( + "num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong))); + } else { + running_mean = this->register_buffer("running_mean", Tensor()); + running_var = this->register_buffer("running_var", Tensor()); + num_batches_tracked = + this->register_buffer("num_batches_tracked", Tensor()); + } + reset_parameters(); + } + + void reset_running_stats() { + if (options.track_running_stats()) { + running_mean.zero_(); + running_var.fill_(1); + num_batches_tracked.zero_(); + } + } + + void reset_parameters() { + reset_running_stats(); + if (options.affine()) { + torch::nn::init::ones_(weight); + torch::nn::init::zeros_(bias); + } + } + + /// The options with which this module was constructed. + DerivedOptions options; + + /// The learned weight. + /// Only defined if the `affine` option was `true` upon construction. + Tensor weight; + + /// The learned bias. + /// Only defined if the `affine` option was `true` upon construction. + Tensor bias; + + /// The running mean. + /// Only defined if the `track_running_stats` option was `true` upon + /// construction. + Tensor running_mean; + + /// The running variance. + /// Only defined if the `track_running_stats` option was `true` upon + /// construction. + Tensor running_var; + + /// The number of the forward call. + /// Only defined if the `track_running_stats` option was `true` upon + /// construction. + Tensor num_batches_tracked; +}; + +/// Base class for all (dimension-specialized) batchnorm modules. +template +class BatchNormImplBase : public NormImplBase { + public: + using NormImplBase::NormImplBase; + + Tensor forward(const Tensor& input) { + this->_check_input_dim(input); + double exponential_average_factor = 0.0; + if (this->options.momentum().has_value()) { + exponential_average_factor = this->options.momentum().value(); + } + + if (this->is_training() && this->options.track_running_stats()) { + if (this->num_batches_tracked.defined()) { + this->num_batches_tracked += 1; + if (this->options.momentum() == + std::nullopt) { // use cumulative moving average + exponential_average_factor = + 1.0 / this->num_batches_tracked.template item(); + } else { // use exponential moving average + exponential_average_factor = this->options.momentum().value(); + } + } + } + + return torch::nn::functional::detail::batch_norm( + input, + this->running_mean, + this->running_var, + this->weight, + this->bias, + this->is_training() || !this->options.track_running_stats(), + /*momentum=*/exponential_average_factor, + this->options.eps()); + } + + /// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << std::boolalpha << "torch::nn::BatchNorm" << D << "d(" + << this->options.num_features() << ", " + << "eps=" << this->options.eps() << ", " + << "momentum="; + + if (this->options.momentum().has_value()) { + stream << this->options.momentum().value(); + } else { + stream << "None"; + } + + stream << ", " + << "affine=" << this->options.affine() << ", " + << "track_running_stats=" << this->options.track_running_stats() + << ")"; + } +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm1d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the BatchNorm1d function. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.BatchNorm1d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::BatchNorm1dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// BatchNorm1d +/// model(BatchNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +class TORCH_API BatchNorm1dImpl : public BatchNormImplBase<1, BatchNorm1dImpl> { + protected: + void _check_input_dim(const Tensor& input) override; + + public: + using BatchNormImplBase<1, BatchNorm1dImpl>::BatchNormImplBase; +}; + +/// A `ModuleHolder` subclass for `BatchNorm1dImpl`. +/// See the documentation for `BatchNorm1dImpl` class to learn what methods it +/// provides, and examples of how to use `BatchNorm1d` with +/// `torch::nn::BatchNorm1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(BatchNorm1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the BatchNorm2d function. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.BatchNorm2d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::BatchNorm2dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// BatchNorm2d +/// model(BatchNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +class TORCH_API BatchNorm2dImpl : public BatchNormImplBase<2, BatchNorm2dImpl> { + protected: + void _check_input_dim(const Tensor& input) override; + + public: + using BatchNormImplBase<2, BatchNorm2dImpl>::BatchNormImplBase; +}; + +/// A `ModuleHolder` subclass for `BatchNorm2dImpl`. +/// See the documentation for `BatchNorm2dImpl` class to learn what methods it +/// provides, and examples of how to use `BatchNorm2d` with +/// `torch::nn::BatchNorm2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(BatchNorm2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm3d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the BatchNorm3d function. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.BatchNorm3d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::BatchNorm3dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// BatchNorm3d +/// model(BatchNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +class TORCH_API BatchNorm3dImpl : public BatchNormImplBase<3, BatchNorm3dImpl> { + protected: + void _check_input_dim(const Tensor& input) override; + + public: + using BatchNormImplBase<3, BatchNorm3dImpl>::BatchNormImplBase; +}; + +/// A `ModuleHolder` subclass for `BatchNorm3dImpl`. +/// See the documentation for `BatchNorm3dImpl` class to learn what methods it +/// provides, and examples of how to use `BatchNorm3d` with +/// `torch::nn::BatchNorm3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(BatchNorm3d); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/common.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/common.h new file mode 100644 index 0000000000000000000000000000000000000000..e967e2317187299a3b1529a29820daffedc53b91 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/common.h @@ -0,0 +1,99 @@ +#pragma once + +/// This macro enables a module with default arguments in its forward method +/// to be used in a Sequential module. +/// +/// Example usage: +/// +/// Let's say we have a module declared like this: +/// ``` +/// struct MImpl : torch::nn::Module { +/// public: +/// explicit MImpl(int value_) : value(value_) {} +/// torch::Tensor forward(int a, int b = 2, double c = 3.0) { +/// return torch::tensor(a + b + c); +/// } +/// private: +/// int value; +/// }; +/// TORCH_MODULE(M); +/// ``` +/// +/// If we try to use it in a Sequential module and run forward: +/// ``` +/// torch::nn::Sequential seq(M(1)); +/// seq->forward(1); +/// ``` +/// +/// We will receive the following error message: +/// ``` +/// MImpl's forward() method expects 3 argument(s), but received 1. +/// If MImpl's forward() method has default arguments, please make sure +/// the forward() method is declared with a corresponding +/// `FORWARD_HAS_DEFAULT_ARGS` macro. +/// ``` +/// +/// The right way to fix this error is to use the `FORWARD_HAS_DEFAULT_ARGS` +/// macro when declaring the module: +/// ``` +/// struct MImpl : torch::nn::Module { +/// public: +/// explicit MImpl(int value_) : value(value_) {} +/// torch::Tensor forward(int a, int b = 2, double c = 3.0) { +/// return torch::tensor(a + b + c); +/// } +/// protected: +/// /* +/// NOTE: looking at the argument list of `forward`: +/// `forward(int a, int b = 2, double c = 3.0)` +/// we saw the following default arguments: +/// ---------------------------------------------------------------- +/// 0-based index of default | Default value of arg +/// arg in forward arg list | (wrapped by `torch::nn::AnyValue()`) +/// ---------------------------------------------------------------- +/// 1 | torch::nn::AnyValue(2) +/// 2 | torch::nn::AnyValue(3.0) +/// ---------------------------------------------------------------- +/// Thus we pass the following arguments to the `FORWARD_HAS_DEFAULT_ARGS` +/// macro: +/// */ +/// FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)}, {2, +/// torch::nn::AnyValue(3.0)}) +/// private: +/// int value; +/// }; +/// TORCH_MODULE(M); +/// ``` +/// Now, running the following would work: +/// ``` +/// torch::nn::Sequential seq(M(1)); +/// seq->forward(1); // This correctly populates the default arguments for +/// `MImpl::forward` +/// ``` +#define FORWARD_HAS_DEFAULT_ARGS(...) \ + template \ + friend struct torch::nn::AnyModuleHolder; \ + bool _forward_has_default_args() override { \ + return true; \ + } \ + unsigned int _forward_num_required_args() override { \ + std::vector> args_info{ \ + __VA_ARGS__}; \ + return std::begin(args_info)->first; \ + } \ + std::vector _forward_populate_default_args( \ + std::vector&& arguments) override { \ + std::vector> args_info{ \ + __VA_ARGS__}; \ + unsigned int num_all_args = std::rbegin(args_info)->first + 1; \ + TORCH_INTERNAL_ASSERT( \ + arguments.size() >= _forward_num_required_args() && \ + arguments.size() <= num_all_args); \ + std::vector ret = std::move(arguments); \ + ret.reserve(num_all_args); \ + for (auto& arg_info : args_info) { \ + if (arg_info.first > ret.size() - 1) \ + ret.emplace_back(std::move(arg_info.second)); \ + } \ + return ret; \ + } diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any.h new file mode 100644 index 0000000000000000000000000000000000000000..28f297388757bab2896c43e0eb6af032a24d420d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any.h @@ -0,0 +1,362 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace torch::nn { + +/// Stores a type erased `Module`. +/// +/// The PyTorch C++ API does not impose an interface on the signature of +/// `forward()` in `Module` subclasses. This gives you complete freedom to +/// design your `forward()` methods to your liking. However, this also means +/// there is no unified base type you could store in order to call `forward()` +/// polymorphically for any module. This is where the `AnyModule` comes in. +/// Instead of inheritance, it relies on type erasure for polymorphism. +/// +/// An `AnyModule` can store any `nn::Module` subclass that provides a +/// `forward()` method. This `forward()` may accept any types and return any +/// type. Once stored in an `AnyModule`, you can invoke the underlying module's +/// `forward()` by calling `AnyModule::forward()` with the arguments you would +/// supply to the stored module (though see one important limitation below). +/// Example: +/// +/// \rst +/// .. code-block:: cpp +/// +/// struct GenericTrainer { +/// torch::nn::AnyModule module; +/// +/// void train(torch::Tensor input) { +/// module.forward(input); +/// } +/// }; +/// +/// GenericTrainer trainer1{torch::nn::Linear(3, 4)}; +/// GenericTrainer trainer2{torch::nn::Conv2d(3, 4, 2)}; +/// \endrst +/// +/// As `AnyModule` erases the static type of the stored module (and its +/// `forward()` method) to achieve polymorphism, type checking of arguments is +/// moved to runtime. That is, passing an argument with an incorrect type to an +/// `AnyModule` will compile, but throw an exception at runtime: +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::AnyModule module(torch::nn::Linear(3, 4)); +/// // Linear takes a tensor as input, but we are passing an integer. +/// // This will compile, but throw a `torch::Error` exception at runtime. +/// module.forward(123); +/// \endrst +/// +/// \rst +/// .. attention:: +/// One noteworthy limitation of `AnyModule` is that its `forward()` method +/// does not support implicit conversion of argument types. For example, if +/// the stored module's `forward()` method accepts a `float` and you call +/// `any_module.forward(3.4)` (where `3.4` is a `double`), this will throw +/// an exception. +/// \endrst +/// +/// The return type of the `AnyModule`'s `forward()` method is controlled via +/// the first template argument to `AnyModule::forward()`. It defaults to +/// `torch::Tensor`. To change it, you can write `any_module.forward()`, +/// for example. +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::AnyModule module(torch::nn::Linear(3, 4)); +/// auto output = module.forward(torch::ones({2, 3})); +/// +/// struct IntModule { +/// int forward(int x) { return x; } +/// }; +/// torch::nn::AnyModule module(IntModule{}); +/// int output = module.forward(5); +/// \endrst +/// +/// The only other method an `AnyModule` provides access to on the stored +/// module is `clone()`. However, you may acquire a handle on the module via +/// `.ptr()`, which returns a `shared_ptr`. Further, if you know +/// the concrete type of the stored module, you can get a concrete handle to it +/// using `.get()` where `T` is the concrete module type. +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::AnyModule module(torch::nn::Linear(3, 4)); +/// std::shared_ptr ptr = module.ptr(); +/// torch::nn::Linear linear(module.get()); +/// \endrst +class AnyModule { + public: + /// A default-constructed `AnyModule` is in an empty state. + AnyModule() = default; + + /// Constructs an `AnyModule` from a `shared_ptr` to concrete module object. + template + explicit AnyModule(std::shared_ptr module); + + /// Constructs an `AnyModule` from a concrete module object. + template < + typename ModuleType, + typename = torch::detail::enable_if_module_t> + explicit AnyModule(ModuleType&& module); + + /// Constructs an `AnyModule` from a module holder. + template + explicit AnyModule(const ModuleHolder& module_holder); + + /// Move construction and assignment is allowed, and follows the default + /// behavior of move for `std::unique_ptr`. + AnyModule(AnyModule&&) = default; + AnyModule& operator=(AnyModule&&) = default; + + /// Creates a shallow copy of an `AnyModule`. + AnyModule(const AnyModule& other); + AnyModule& operator=(const AnyModule& other); + + /// Creates a deep copy of an `AnyModule` if it contains a module, else an + /// empty `AnyModule` if it is empty. + AnyModule clone(std::optional device = std::nullopt) const; + + /// Assigns a module to the `AnyModule` (to circumvent the explicit + /// constructor). + template + AnyModule& operator=(std::shared_ptr module); + + /// Invokes `forward()` on the contained module with the given arguments, and + /// returns the return value as an `AnyValue`. Use this method when chaining + /// `AnyModule`s in a loop. + template + AnyValue any_forward(ArgumentTypes&&... arguments); + + /// Invokes `forward()` on the contained module with the given arguments, and + /// casts the returned `AnyValue` to the supplied `ReturnType` (which defaults + /// to `torch::Tensor`). + template + ReturnType forward(ArgumentTypes&&... arguments); + + /// Attempts to cast the underlying module to the given module type. Throws an + /// exception if the types do not match. + template > + T& get(); + + /// Attempts to cast the underlying module to the given module type. Throws an + /// exception if the types do not match. + template > + const T& get() const; + + /// Returns the contained module in a `nn::ModuleHolder` subclass if possible + /// (i.e. if `T` has a constructor for the underlying module type). + template + T get() const; + + /// Returns a `std::shared_ptr` whose dynamic type is that of the underlying + /// module. + std::shared_ptr ptr() const; + + /// Like `ptr()`, but casts the pointer to the given type. + template > + std::shared_ptr ptr() const; + + /// Returns the `type_info` object of the contained value. + const std::type_info& type_info() const; + + /// Returns true if the `AnyModule` does not contain a module. + bool is_empty() const noexcept; + + private: + /// Creates a `unique_ptr` pointing to a + /// `AnyModuleHolder` of the correct type. This method is used to deduce the + /// arguments of the module's `forward()` method. + template < + typename ModuleType, + typename Class, + typename ReturnType, + typename... ArgumentTypes> + std::unique_ptr make_holder( + std::shared_ptr&& module, + ReturnType (Class::*)(ArgumentTypes...)); + + /// Helper method invoked by const and non-const `get()`. + template + ModuleType& get_(ReturnType (ModuleType::*)(ArgumentTypes...)) const; + + /// Helper method invoked by const and non-const `get()`. + template + ModuleType& get_() const; + + /// The type erased module. + std::unique_ptr content_; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +AnyModule::AnyModule(std::shared_ptr module) + : content_(make_holder( + std::move(module), + &std::remove_reference_t::forward)) { + // `AnyModule` can only store an `nn::Module` subclass object that provides + // a `forward()` method that has a non-templatized return type. + // (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s + // `forward()` method has a templatized return type.) + static_assert( + torch::detail::is_module::value, + "Can only store object derived from nn::Module into AnyModule"); + static_assert( + torch::detail::has_forward::value, + "Can only store module with a forward() method that has a non-templatized" + " argument type and return type into AnyModule (e.g. we cannot store nn::Sequential" + "into AnyModule, because its forward() method's argument type and return type are templatized." + " If you need to use nn::Sequentials inside each other you can subclass " + "nn::Sequential and write a non-templatized forward function for it. You can checkout " + "https://github.com/pytorch/vision/blob/2f46070f3cb1ea894d82578f3dc5677f82f34958/torchvision/csrc/models/mnasnet.cpp#L59 " + "for an example on how to do this.)."); +} + +template +AnyModule::AnyModule(ModuleType&& module) + : AnyModule( + std::make_shared(std::forward(module))) {} + +template +AnyModule::AnyModule(const ModuleHolder& module_holder) + : AnyModule(module_holder.ptr()) {} + +inline AnyModule::AnyModule(const AnyModule& other) + : content_(other.content_ ? other.content_->copy() : nullptr) {} + +inline AnyModule& AnyModule::operator=(const AnyModule& other) { + if (this != &other) { + content_ = other.content_ ? other.content_->copy() : nullptr; + } + return *this; +} + +inline AnyModule AnyModule::clone(std::optional device) const { + AnyModule clone; + clone.content_ = content_ ? content_->clone_module(device) : nullptr; + return clone; +} + +template +AnyModule& AnyModule::operator=(std::shared_ptr module) { + *this = AnyModule(std::move(module)); + return *this; +} + +template +AnyValue AnyModule::any_forward(ArgumentTypes&&... arguments) { + TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule"); + std::vector values; + values.reserve(sizeof...(ArgumentTypes)); + torch::apply( + [&values](AnyValue&& value) { values.push_back(std::move(value)); }, + AnyValue(std::forward(arguments))...); + return content_->forward(std::move(values)); +} + +template +ReturnType AnyModule::forward(ArgumentTypes&&... arguments) { + return any_forward(std::forward(arguments)...) + .template get(); +} + +template +T& AnyModule::get() { + TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule"); + return get_(); +} + +template +const T& AnyModule::get() const { + TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule"); + return get_(); +} + +template +T AnyModule::get() const { + return T(ptr()); +} + +inline std::shared_ptr AnyModule::ptr() const { + TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule"); + return content_->ptr(); +} + +template +std::shared_ptr AnyModule::ptr() const { + TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule"); + // Call get() but discard the value, just to do the type checking. + get_(); + return std::dynamic_pointer_cast(ptr()); +} + +inline const std::type_info& AnyModule::type_info() const { + TORCH_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule"); + return content_->type_info; +} + +inline bool AnyModule::is_empty() const noexcept { + return content_ == nullptr; +} + +// Private Methods + +template < + typename ModuleType, + typename Class, + typename ReturnType, + typename... ArgumentTypes> +std::unique_ptr AnyModule::make_holder( + std::shared_ptr&& module, + ReturnType (Class::*)(ArgumentTypes...)) { + static_assert( + torch::detail::check_not_lvalue_references(), + "Modules stored inside AnyModule must not take references. " + "Use pointers instead."); + static_assert( + !std::is_void_v, + "AnyModule cannot store modules that return void " + "(you can return a dummy value)."); + return std::make_unique< + AnyModuleHolder, ArgumentTypes...>>( + std::move(module)); +} + +template +ModuleType& AnyModule::get_() const { + using M = std::remove_reference_t; + static_assert( + torch::detail::has_forward::value, + "Can only call AnyModule::get with a type T that has a forward method"); + return get_(&M::forward); +} + +template +ModuleType& AnyModule::get_( + ReturnType (ModuleType::*)(ArgumentTypes...)) const { + if (typeid(ModuleType).hash_code() == type_info().hash_code()) { + return *static_cast&>( + *content_) + .module; + } + TORCH_CHECK( + false, + "Attempted to cast module of type ", + c10::demangle(type_info().name()), + " to type ", + c10::demangle(typeid(ModuleType).name())); +} + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h new file mode 100644 index 0000000000000000000000000000000000000000..7482ef3b452d9188ad32cc66cce5c57462bd6edc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h @@ -0,0 +1,135 @@ +#pragma once + +#include +#include + +namespace torch::nn { + +class Module; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModulePlaceholder ~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// The static type of the object we store in the `AnyModule`, which erases +/// the actual type, but allows us to call `forward()` on the underlying +/// module. +struct AnyModulePlaceholder : public AnyValue::Placeholder { + using AnyValue::Placeholder::Placeholder; + + /// The "erased" `forward()` method. + virtual AnyValue forward(std::vector&& arguments) = 0; + + /// Returns std::shared_ptr pointing to the erased module. + virtual std::shared_ptr ptr() = 0; + + /// Returns a `AnyModulePlaceholder` with a shallow copy of this `AnyModule`. + virtual std::unique_ptr copy() const = 0; + + /// Returns a `AnyModulePlaceholder` with a deep copy of this `AnyModule`. + virtual std::unique_ptr clone_module( + std::optional device) const = 0; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModuleHolder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// The dynamic type of the object stored in the `AnyModule`. It contains the +/// concrete instance to which all calls are forwarded. It is parameterized +/// over the concrete type of the module, and the types of the arguments the +/// module takes in its `forward()` method. +template +struct AnyModuleHolder : public AnyModulePlaceholder { + /// \internal + struct CheckedGetter { + template + std::decay_t&& operator()(size_t index) { + AT_ASSERT(index < arguments_.size()); + auto& value = arguments_[index]; + if (auto* maybe_value = value.template try_get>()) { + return std::move(*maybe_value); + } + TORCH_CHECK( + false, + "Expected argument #", + index, + " to be of type ", + c10::demangle(typeid(T).name()), + ", but received value of type ", + c10::demangle(value.type_info().name())); + } + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + std::vector& arguments_; + }; + + /// \internal + struct InvokeForward { + template + AnyValue operator()(Ts&&... ts) { + return AnyValue(module_->forward(std::forward(ts)...)); + } + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + std::shared_ptr& module_; + }; + + /// Constructs the `AnyModuleHolder` from a concrete module. + explicit AnyModuleHolder(std::shared_ptr&& module_) + : AnyModulePlaceholder(typeid(ModuleType)), module(std::move(module_)) {} + + /// Calls `forward()` on the underlying module, casting each `AnyValue` in the + /// argument vector to a concrete value. + AnyValue forward(std::vector&& arguments) override { + if (module->_forward_has_default_args()) { + TORCH_CHECK( + arguments.size() >= module->_forward_num_required_args() && + arguments.size() <= sizeof...(ArgumentTypes), + c10::demangle(type_info.name()), + "'s forward() method expects at least ", + module->_forward_num_required_args(), + " argument(s) and at most ", + sizeof...(ArgumentTypes), + " argument(s), but received ", + arguments.size(), + "."); + arguments = std::move( + module->_forward_populate_default_args(std::move(arguments))); + } else { + std::string use_default_args_macro_prompt = " If " + + c10::demangle(type_info.name()) + + "'s forward() method has default arguments, " + + "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."; + TORCH_CHECK( + arguments.size() == sizeof...(ArgumentTypes), + c10::demangle(type_info.name()), + "'s forward() method expects ", + sizeof...(ArgumentTypes), + " argument(s), but received ", + arguments.size(), + ".", + (arguments.size() < sizeof...(ArgumentTypes)) + ? use_default_args_macro_prompt + : ""); + } + + // FYI: During invocation of a module's `forward()` method, the values live + // in the `arguments` vector inside this function. + return torch::unpack( + InvokeForward{module}, CheckedGetter{arguments}); + } + + std::shared_ptr ptr() override { + return module; + } + + std::unique_ptr copy() const override { + return std::make_unique(*this); + } + + std::unique_ptr clone_module( + std::optional device) const override { + return std::make_unique( + std::dynamic_pointer_cast(module->clone(device))); + } + + /// The actual concrete module instance. + std::shared_ptr module; +}; + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_value.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_value.h new file mode 100644 index 0000000000000000000000000000000000000000..bad19a44014501b38da4f542293b22348fe19c8a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/any_value.h @@ -0,0 +1,124 @@ +#pragma once + +#include + +#include +#include +#include +#include + +namespace torch::nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyValue ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// An implementation of `std::any` which stores +/// a type erased object, whose concrete value can be retrieved at runtime by +/// checking if the `typeid()` of a requested type matches the `typeid()` of +/// the object stored. +class AnyValue { + public: + /// Move construction and assignment is allowed, and follows the default + /// behavior of move for `std::unique_ptr`. + AnyValue(AnyValue&&) = default; + AnyValue& operator=(AnyValue&&) = default; + ~AnyValue() = default; + + /// Copy construction and assignment is allowed. + AnyValue(const AnyValue& other) : content_(other.content_->clone()) {} + AnyValue& operator=(const AnyValue& other) { + content_ = other.content_->clone(); + return *this; + } + + /// Constructs the `AnyValue` from value type. + template < + typename T, + typename = std::enable_if_t>> + explicit AnyValue(T&& value) + : content_( + std::make_unique>>(std::forward(value))) { + } + + /// Returns a pointer to the value contained in the `AnyValue` if the type + /// passed as template parameter matches the type of the value stored, and + /// returns a null pointer otherwise. + template + T* try_get() { + static_assert( + !std::is_reference_v, + "AnyValue stores decayed types, you cannot cast it to a reference type"); + static_assert( + !std::is_array_v, + "AnyValue stores decayed types, you must cast it to T* instead of T[]"); + if (typeid(T).hash_code() == type_info().hash_code()) { + return &static_cast&>(*content_).value; + } + return nullptr; + } + + /// Returns the value contained in the `AnyValue` if the type passed as + /// template parameter matches the type of the value stored, and throws an + /// exception otherwise. + template + T get() { + if (auto* maybe_value = try_get()) { + return *maybe_value; + } + TORCH_CHECK( + false, + "Attempted to cast AnyValue to ", + c10::demangle(typeid(T).name()), + ", but its actual type is ", + c10::demangle(type_info().name())); + } + + /// Returns the `type_info` object of the contained value. + const std::type_info& type_info() const noexcept { + return content_->type_info; + } + + private: + friend struct AnyModulePlaceholder; + friend struct TestAnyValue; + + /// \internal + /// The static type of the object we store in the `AnyValue`, which erases the + /// actual object's type, allowing us only to check the `type_info` of the + /// type stored in the dynamic type. + struct Placeholder { + explicit Placeholder(const std::type_info& type_info_) noexcept + : type_info(type_info_) {} + Placeholder(const Placeholder&) = default; + Placeholder(Placeholder&&) = default; + Placeholder& operator=(const Placeholder&) = delete; + Placeholder& operator=(Placeholder&&) = delete; + virtual ~Placeholder() = default; + virtual std::unique_ptr clone() const { + TORCH_CHECK(false, "clone() should only be called on `AnyValue::Holder`"); + } + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::type_info& type_info; + }; + + /// \internal + /// The dynamic type of the object we store in the `AnyValue`, which hides the + /// actual object we have erased in this `AnyValue`. + template + struct Holder : public Placeholder { + /// A template because T&& would not be universal reference here. + template < + typename U, + typename = std::enable_if_t>> + explicit Holder(U&& value_) noexcept + : Placeholder(typeid(T)), value(std::forward(value_)) {} + std::unique_ptr clone() const override { + return std::make_unique>(value); + } + T value; + }; + + /// The type erased object. + std::unique_ptr content_; +}; + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/functional.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/functional.h new file mode 100644 index 0000000000000000000000000000000000000000..fac31d204f5aea1b18f97d2d40520e7880c818cf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/functional.h @@ -0,0 +1,101 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::nn { + +/// Wraps a function in a `Module`. +/// +/// The `Functional` module allows wrapping an arbitrary function or function +/// object in an `nn::Module`. This is primarily handy for usage in +/// `Sequential`. +/// +/// \rst +/// .. code-block:: cpp +/// +/// Sequential sequential( +/// Linear(3, 4), +/// Functional(torch::relu), +/// BatchNorm1d(3), +/// Functional(torch::elu, /*alpha=*/1)); +/// \endrst +/// +/// While a `Functional` module only accepts a single `Tensor` as input, it is +/// possible for the wrapped function to accept further arguments. However, +/// these have to be bound *at construction time*. For example, if +/// you want to wrap `torch::leaky_relu`, which accepts a `slope` scalar as its +/// second argument, with a particular value for its `slope` in a `Functional` +/// module, you could write +/// +/// \rst +/// .. code-block:: cpp +/// +/// Functional(torch::leaky_relu, /*slope=*/0.5) +/// \endrst +/// +/// The value of `0.5` is then stored within the `Functional` object and +/// supplied to the function call at invocation time. Note that such bound +/// values are evaluated eagerly and stored a single time. See the documentation +/// of [std::bind](https://en.cppreference.com/w/cpp/utility/functional/bind) +/// for more information on the semantics of argument binding. +/// +/// \rst +/// .. attention:: +/// After passing any bound arguments, the function must accept a single +/// tensor and return a single tensor. +/// \endrst +/// +/// Note that `Functional` overloads the call operator (`operator()`) such that +/// you can invoke it with `my_func(...)`. +class TORCH_API FunctionalImpl : public torch::nn::Cloneable { + public: + using Function = std::function; + + /// Constructs a `Functional` from a function object. + explicit FunctionalImpl(Function function); + + template < + typename SomeFunction, + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) > 0)>> + explicit FunctionalImpl(SomeFunction original_function, Args&&... args) + // NOLINTNEXTLINE(modernize-avoid-bind) + : function_(std::bind( + original_function, + /*input=*/std::placeholders::_1, + std::forward(args)...)) { + // std::bind is normally evil, but (1) gcc is broken w.r.t. handling + // parameter pack expansion in lambdas and (2) moving parameter packs into + // a lambda only works with C++14, so std::bind is the more move-aware + // solution here. + } + + void reset() override; + + /// Pretty prints the `Functional` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// Forwards the `input` tensor to the underlying (bound) function object. + Tensor forward(Tensor input); + + /// Calls forward(input). + Tensor operator()(Tensor input); + + bool is_serializable() const override; + + private: + Function function_; +}; + +/// A `ModuleHolder` subclass for `FunctionalImpl`. +/// See the documentation for `FunctionalImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(Functional); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/moduledict.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/moduledict.h new file mode 100644 index 0000000000000000000000000000000000000000..246ed8abb633bf9fc7fcafe2f3d8da786d3ff29f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/moduledict.h @@ -0,0 +1,260 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn { + +/// An OrderedDict of `Module`s that registers its elements by their `key`s. +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::OrderedDict> ordereddict = { +/// {"linear", Linear(10, 3).ptr()}, +/// {"conv", Conv2d(1, 2, 3).ptr()}, +/// {"dropout", Dropout(0.5).ptr()}, +/// }; +/// torch::nn::ModuleDict dict1(ordereddict); +/// +/// for (const auto &module : *dict1) { +/// module->pretty_print(std::cout); +/// } +/// +/// std::vector>> list = { +/// {"linear", Linear(10, 3).ptr()}, +/// {"conv", Conv2d(1, 2, 3).ptr()}, +/// {"dropout", Dropout(0.5).ptr()}, +/// }; +/// torch::nn::ModuleDict dict2(list); +/// +/// for (const auto &module : *dict2) { +/// module->pretty_print(std::cout); +/// } +/// +/// \endrst +/// +/// Why should you use `ModuleDict` instead of a simple `map` or `OrderedDict`? +/// The value a `ModuleDict` provides over manually calling an ordered map of +/// modules is that it allows treating the whole container *as a single module*, +/// such that performing a transformation on the `ModuleDict` applies to each of +/// the modules it stores (which are each a registered submodule of the +/// `ModuleDict`). For example, calling `.to(torch::kCUDA)` on a `ModuleDict` +/// will move each module in the map to CUDA memory. For example: +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::OrderedDict> ordereddict = { +/// {"linear", Linear(10, 3).ptr()}, +/// {"conv", Conv2d(1, 2, 3).ptr()}, +/// {"dropout", Dropout(0.5).ptr()}, +/// }; +/// torch::nn::ModuleDict dict(ordereddict); +/// +/// // Convert all modules to CUDA. +/// dict->to(torch::kCUDA); +/// +/// \endrst +/// +/// Finally, `ModuleDict` provides a lightweight container API, such as allowing +/// iteration over submodules, positional access, adding new modules from a +/// vector of key-module pairs or an `OrderedDict` or another `ModuleDict` after +/// construction via `update`. +class ModuleDictImpl : public Cloneable { + public: + using Iterator = + torch::OrderedDict>::Iterator; + using ConstIterator = + torch::OrderedDict>::ConstIterator; + + ModuleDictImpl() = default; + + /// Constructs the `ModuleDict` from a list of string-Module pairs. + explicit ModuleDictImpl( + const std::vector>>& + modules) { + update(modules); + } + + /// Constructs the `ModuleDict` from an `OrderedDict`. + explicit ModuleDictImpl( + const torch::OrderedDict>& modules) { + update(modules); + } + + /// Return the items in the `ModuleDict`. + std::vector>> items() const { + return modules_.pairs(); + } + + /// Return the keys in the `ModuleDict`. + std::vector keys() const { + return modules_.keys(); + } + + /// Return the values in the `ModuleDict`. + std::vector> values() const { + return modules_.values(); + } + + /// Return an iterator to the start of `ModuleDict`. + Iterator begin() { + return modules_.begin(); + } + + /// Return a const iterator to the start of `ModuleDict`. + ConstIterator begin() const { + return modules_.begin(); + } + + /// Return an iterator to the end of `ModuleDict`. + Iterator end() { + return modules_.end(); + } + + /// Return a const iterator to the end of `ModuleDict`. + ConstIterator end() const { + return modules_.end(); + } + + /// Return the number of items currently stored in the `ModuleDict`. + size_t size() const noexcept { + return modules_.size(); + } + + /// Return true if the `ModuleDict` is empty, otherwise return false. + bool empty() const noexcept { + return modules_.is_empty(); + } + + /// Check if the certain parameter with the key in the `ModuleDict`. + bool contains(const std::string& key) const noexcept { + return modules_.contains(key); + } + + /// Remove all items from the `ModuleDict`. + void clear() { + // Not remove the registration of modules to make it consistent with python + // version. + modules_.clear(); + } + + /// Special cloning function for `ModuleDict` because it does not use + /// `reset()`. + std::shared_ptr clone( + const std::optional& device = std::nullopt) const override { + auto clone = std::make_shared(); + for (const auto& module : modules_) { + clone->insert(module.key(), module.value()->clone(device)); + } + return clone; + } + + /// `reset()` is empty for `ModuleDict`, since it does not have parameters of + /// its own. + void reset() override {} + + /// Pretty prints the `ModuleDict` into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::ModuleDict"; + } + + /// Attempts to returns the `Module` associated with the given `key`. Throws + /// an exception if no such `key` is stored in the `ModuleDict`. Check + /// contains(key) before for a non-throwing way of access. + std::shared_ptr operator[](const std::string& key) const { + return modules_[key]; + } + + /// Attempts to return the module at the given key as the requested type. + /// Throws an exception if no such `key` is stored in the `ModuleDict`. + /// Check contains(key) before for a non-throwing way of access. + template + T& at(const std::string& key) { + static_assert( + torch::detail::is_module::value, + "Can only call ModuleList::at with an nn::Module type"); + auto module = modules_[key]->as(); + TORCH_CHECK( + module, + "Unable to cast module[", + key, + "] to ", + c10::demangle(typeid(T).name())); + return *module; + } + + /// Attempts to return the module at the given key as the requested type. + /// Throws an exception if no such `key` is stored in the `ModuleDict`. + /// Check contains(key) before for a non-throwing way of access. + template + const T& at(const std::string& key) const { + static_assert( + torch::detail::is_module::value, + "Can only call ModuleList::at with an nn::Module type"); + const auto module = modules_[key]->as(); + TORCH_CHECK( + module, + "Unable to cast module[", + key, + "] to ", + c10::demangle(typeid(T).name())); + return *module; + } + + /// Removes and returns the `Module` associated with the given `key`. + /// Throws an exception if no such `key` is stored in the `ModuleDict`. + /// Check contains(key) before for a non-throwing way of access. + std::shared_ptr pop(const std::string& key) { + auto module = modules_[key]; + modules_.erase(key); + // Not remove the registration of the module to make it consistent with + // python version. + return module; + } + + /// Updated the `ModuleDict` with a vector of key-module pairs. + void update( + const std::vector>>& + modules) { + for (auto& item : modules) { + insert(item.first, item.second); + } + } + + /// Updated the `ModuleDict` with key-value pairs from `OrderedDict` or + /// `ModuleDict`. + template + void update(const Container& container) { + for (auto& item : container) { + insert(item.key(), item.value()); + } + } + + private: + /// Private `OrderedDict` holding the key-Module pairs. + torch::OrderedDict> modules_; + + /// Insert a key-module pair by overwriting existing keys, + /// and register or replace the `Module`. + void insert(const std::string& key, std::shared_ptr module) { + if (contains(key)) { + modules_[key] = std::move(module); + replace_module(key, modules_[key]); + } else { + modules_.insert(key, std::move(module)); + register_module(key, modules_.back().value()); + } + } +}; + +/// A `ModuleHolder` subclass for `ModuleDictImpl`. +/// See the documentation for `ModuleDictImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(ModuleDict); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/modulelist.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/modulelist.h new file mode 100644 index 0000000000000000000000000000000000000000..6147a73db4b4b92924620f65a76ef602d78751bb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/modulelist.h @@ -0,0 +1,272 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::nn { + +/// A list of `Module`s that registers its elements. +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::ModuleList mlist( +/// torch::nn::Linear(3, 4), +/// torch::nn::BatchNorm1d(4), +/// torch::nn::Dropout(0.5) +/// ); +/// +/// for (const auto &module : *mlist) { +/// module->pretty_print(std::cout); +/// } +/// +/// \endrst +/// +/// Why should you use `ModuleList` instead of a simple `std::vector`? The value +/// a `ModuleList` provides over manually calling a sequence of modules is that +/// it allows treating the whole container *as a single module*, such that +/// performing a transformation on the `ModuleList` applies to each of the +/// modules it stores (which are each a registered submodule of the +/// `ModuleList`). For example, calling +/// `.to(torch::kCUDA)` on a `ModuleList` will move each module in the list to +/// CUDA memory. For example: +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::ModuleList mlist( +/// torch::nn::Linear(3, 4), +/// torch::nn::BatchNorm1d(4), +/// torch::nn::Dropout(0.5) +/// ); +/// +/// // Convert all modules to CUDA. +/// mlist->to(torch::kCUDA); +/// +/// \endrst +/// +/// Finally, `ModuleList` provides a lightweight container API, such as allowing +/// iteration over submodules, positional access, adding a new module after +/// construction via `push_back`, as well as joining two `ModuleList`s via +/// `extend`. +class ModuleListImpl : public Cloneable { + public: + using Iterator = std::vector>::iterator; + using ConstIterator = std::vector>::const_iterator; + + ModuleListImpl() = default; + + /// Constructs the `ModuleList` from a variadic list of modules. + template + explicit ModuleListImpl(Modules&&... modules) { + modules_.reserve(sizeof...(Modules)); + push_back_var(std::forward(modules)...); + } + + /// Special cloning function for `ModuleList` because it does not use + /// `reset()`. + std::shared_ptr clone( + const std::optional& device = std::nullopt) const override { + auto clone = std::make_shared(); + for (const auto& module : modules_) { + clone->push_back(module->clone(device)); + } + return clone; + } + + /// `reset()` is empty for `ModuleList`, since it does not have parameters of + /// its own. + void reset() override {} + + /// Pretty prints the `ModuleList` module into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::ModuleList"; + } + + void push_back(std::shared_ptr module) { + modules_.push_back(std::move(module)); + const auto index = modules_.size() - 1; + register_module(std::to_string(index), modules_[index]); + } + + /// Adds a new `Module` to the `ModuleList` container, moving or copying + /// it into a `shared_ptr` internally. This method allows passing value types, + /// and letting the container deal with the boxing. + template > + void push_back(M&& module) { + using Type = std::remove_reference_t; + push_back(std::make_shared(std::forward(module))); + } + + /// Unwraps the contained module of a `ModuleHolder` and adds it to the + /// `ModuleList`. + template + void push_back(const ModuleHolder& module_holder) { + push_back(module_holder.ptr()); + } + + /// Iterates over the container and calls `push_back()` on each value. + template + void extend(const Container& container) { + for (const auto& module : container) { + push_back(module); + } + } + + /// Returns an iterator to the start of the `ModuleList`. + Iterator begin() { + return modules_.begin(); + } + + /// Returns a const iterator to the start of the `ModuleList`. + ConstIterator begin() const { + return modules_.begin(); + } + + /// Returns an iterator to the end of the `ModuleList`. + Iterator end() { + return modules_.end(); + } + + /// Returns a const iterator to the end of the `ModuleList`. + ConstIterator end() const { + return modules_.end(); + } + + /// Attempts to return the module at the given index as the requested type. + /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + T& at(size_t index) { + static_assert( + torch::detail::is_module::value, + "Can only call ModuleList::at with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + auto module = modules_[index]->as(); + TORCH_CHECK( + module, + "Unable to cast module[", + index, + "] to ", + c10::demangle(typeid(T).name())); + return *module; + } + + /// Attempts to return the module at the given index as the requested type. + /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + const T& at(size_t index) const { + static_assert( + torch::detail::is_module::value, + "Can only call ModuleList::at with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + const auto module = modules_[index]->as(); + TORCH_CHECK( + module, + "Unable to cast module[", + index, + "] to ", + c10::demangle(typeid(T).name())); + return *module; + } + + /// Attempts to return a `std::shared_ptr` whose dynamic type is that of the + /// underlying module at the given index. Throws an exception if the index is + /// out of bounds. + std::shared_ptr ptr(size_t index) const { + TORCH_CHECK(index < size(), "Index out of range"); + return modules_[index]; + } + + /// Attempts to return a `std::shared_ptr` whose type is the one provided. + /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + std::shared_ptr ptr(size_t index) const { + static_assert( + torch::detail::is_module::value, + "Can only call ModuleList::ptr with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + return std::dynamic_pointer_cast(modules_[index]); + } + + /// Like `ptr(index)`. + std::shared_ptr operator[](size_t index) const { + // This is the only method we can call without a type. + return ptr(index); + } + + /// The current size of the `ModuleList` container. + size_t size() const noexcept { + return modules_.size(); + } + + /// True if there are no modules in the `ModuleList`. + bool is_empty() const noexcept { + return size() == 0; + } + + void insert(size_t index, std::shared_ptr module) { + TORCH_CHECK(index <= size(), "Index out of range"); + + if (index == size()) + push_back(std::move(module)); + else { + modules_.insert( + modules_.begin() + Iterator::difference_type(index), + std::move(module)); + + for (const auto i : c10::irange(index, size() - 1)) { + (void)i; // Suppress unused variable warning + replace_module(std::to_string(index), modules_[index]); + } + register_module(std::to_string(size() - 1), modules_.back()); + } + } + + /// Unwraps the contained module of a `ModuleHolder` and inserts it in the + /// `ModuleList`. + template + void insert(size_t index, const ModuleHolder& module_holder) { + insert(index, module_holder.ptr()); + } + + /// inserts a new `Module` to the `ModuleList` container, moving or copying + /// it into a `shared_ptr` internally. This method allows passing value types, + /// and letting the container deal with the boxing. + template > + void insert(size_t index, M&& module) { + using Type = std::remove_reference_t; + insert(index, std::make_shared(std::forward(module))); + } + + private: + template + void push_back_var(Head&& head, Tail&&... tail) { + push_back(std::forward(head)); + // Recursively calls this method, until the parameter pack only thas this + // entry left. Then calls `push_back()` a final time (above). + push_back_var(std::forward(tail)...); + } + + /// The base case, when the list of modules is empty. + void push_back_var() {} + + // Box the AnyModules to give ModuleList reference semantics, like the rest of + // the API. Note that this is not required otherwise, this could just be a + // `vector`. + std::vector> modules_; +}; + +/// A `ModuleHolder` subclass for `ModuleListImpl`. +/// See the documentation for `ModuleListImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(ModuleList); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/named_any.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/named_any.h new file mode 100644 index 0000000000000000000000000000000000000000..9b7c01b08e9cfaaf9722e3e23efc4db88ebd0e59 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/named_any.h @@ -0,0 +1,81 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch::nn { + +/// Stores a type erased `Module` with name. +/// +/// The `NamedAnyModule` class enables the following API for constructing +/// `nn::Sequential` with named submodules: +/// \rst +/// .. code-block:: cpp +/// +/// struct M : torch::nn::Module { +/// explicit M(int value_) : value(value_) {} +/// int value; +/// int forward() { +/// return value; +/// } +/// }; +/// +/// Sequential sequential({ +/// {"m1", std::make_shared(1)}, // shared pointer to `Module` is +/// supported {std::string("m2"), M(2)}, // `Module` is supported +/// {"linear1", Linear(10, 3)} // `ModuleHolder` is supported +/// }); +/// \endrst +class NamedAnyModule { + public: + /// Creates a `NamedAnyModule` from a (boxed) `Module`. + template + NamedAnyModule(std::string name, std::shared_ptr module_ptr) + : NamedAnyModule(std::move(name), AnyModule(std::move(module_ptr))) {} + + /// Creates a `NamedAnyModule` from a `Module`, moving or copying it + /// into a `shared_ptr` internally. + // NOTE: We need to use `std::remove_reference_t` to get rid of + // any reference components for make_unique. + template > + NamedAnyModule(std::string name, M&& module) + : NamedAnyModule( + std::move(name), + std::make_shared>( + std::forward(module))) {} + + /// Creates a `NamedAnyModule` from a `Module` that is unwrapped from + /// a `ModuleHolder`. + template + NamedAnyModule(std::string name, const ModuleHolder& module_holder) + : NamedAnyModule(std::move(name), module_holder.ptr()) {} + + /// Creates a `NamedAnyModule` from a type-erased `AnyModule`. + NamedAnyModule(std::string name, AnyModule any_module) + : name_(std::move(name)), module_(std::move(any_module)) {} + + /// Returns a reference to the name. + const std::string& name() const noexcept { + return name_; + } + + /// Returns a reference to the module. + AnyModule& module() noexcept { + return module_; + } + + /// Returns a const reference to the module. + const AnyModule& module() const noexcept { + return module_; + } + + private: + std::string name_; + AnyModule module_; +}; + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h new file mode 100644 index 0000000000000000000000000000000000000000..008d790fdece111a6e7adc8eb18989d8f156949a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h @@ -0,0 +1,146 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch::nn { + +class ParameterDictImpl : public Cloneable { + public: + using Iterator = OrderedDict::Iterator; + using ConstIterator = OrderedDict::ConstIterator; + + ParameterDictImpl() = default; + + explicit ParameterDictImpl( + const torch::OrderedDict& params) { + parameters_ = params; + } + + /// `reset()` is empty for `ParameterDict`, since it does not have + /// parameters of its own. + void reset() override {} + + /// Pretty prints the `ParameterDict` module into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::ParameterDict(" << '\n'; + for (const auto& pair : parameters_) { + stream << "(" << pair.key() << ")" + << ": Parameter containing: [" << pair.value().scalar_type() + << " of size " << pair.value().sizes() << "]"; + ; + stream << '\n'; + } + stream << ")"; + } + + /// Insert the parameter along with the key into ParameterDict + /// The parameter is set to be require grad by default + Tensor& insert(const std::string& key, const Tensor& param) { + bool requires_grad = param.requires_grad(); + return register_parameter(key, param, requires_grad); + } + + /// Remove key from the ParameterDict and return its value, throw exception + /// if the key is not contained. Please check contains(key) before for a + /// non-throwing access. + Tensor pop(const std::string& key) { + torch::Tensor v = parameters_[key]; + parameters_.erase(key); + return v; + } + + /// Return the keys in the dict + ::std::vector keys() const { + return parameters_.keys(); + } + + /// Return the Values in the dict + ::std::vector values() const { + return parameters_.values(); + } + + /// Return an iterator to the start of ParameterDict + Iterator begin() { + return parameters_.begin(); + } + + /// Return a const iterator to the start of ParameterDict + ConstIterator begin() const { + return parameters_.begin(); + } + + /// Return an iterator to the end of ParameterDict + Iterator end() { + return parameters_.end(); + } + + /// Return a const iterator to the end of ParameterDict + ConstIterator end() const { + return parameters_.end(); + } + + /// Return the number of items currently stored in the ParameterDict + size_t size() const noexcept { + return parameters_.size(); + } + + /// Return true if the ParameterDict is empty, otherwise return false + bool empty() const noexcept { + return parameters_.is_empty(); + } + + /// Update the ParameterDict with the key-value pairs from + /// another ParameterDict, overwriting existing key + template + void update(const Container& container) { + for (auto& item : container) { + parameters_[item.key()] = item.value(); + } + } + + /// Remove all parameters in the ParameterDict + void clear() { + parameters_.clear(); + } + + /// Check if the certain parameter with the key in the ParameterDict + bool contains(const std::string& key) const noexcept { + return parameters_.contains(key); + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterDict`. Check contains(key) before + /// for a non-throwing way of access + const Tensor& get(const std::string& key) const { + return parameters_[key]; + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterDict`. Check contains(key) before + /// for a non-throwing way of access + Tensor& get(const std::string& key) { + return parameters_[key]; + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterDict`. Check contains(key) before + /// for a non-throwing way of access + Tensor& operator[](const std::string& key) { + return parameters_[key]; + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterDict`. Check contains(key) before + /// for a non-throwing way of access + const Tensor& operator[](const std::string& key) const { + return parameters_[key]; + } +}; + +TORCH_MODULE(ParameterDict); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h new file mode 100644 index 0000000000000000000000000000000000000000..2ea2b52fa0fb969c39ca3bf7f57201105ecdd02d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h @@ -0,0 +1,167 @@ +#pragma once + +#include +#include + +#include + +namespace torch::nn { +class ParameterListImpl : public Cloneable { + public: + using Iterator = typename std::vector< + OrderedDict::Item>::iterator; + using ConstIterator = typename std::vector< + OrderedDict::Item>::const_iterator; + + ParameterListImpl() = default; + + /// Constructs the `ParameterList` from a variadic list of ParameterList. + template + explicit ParameterListImpl(Tensors&&... params) { + parameters_.reserve(sizeof...(Tensors)); + push_back_var(std::forward(params)...); + } + + template + explicit ParameterListImpl(const Tensors&... params) { + parameters_.reserve(sizeof...(Tensors)); + push_back_var(std::forward(params)...); + } + + /// `reset()` is empty for `ParameterList`, since it does not have parameters + /// of its own. + void reset() override {} + + /// Pretty prints the `ParameterList` module into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::ParameterList(" << '\n'; + for (const auto& pair : parameters_) { + stream << "(" << pair.key() << ")" + << ": Parameter containing: [" << pair.value().scalar_type() + << " of size " << pair.value().sizes() << "]"; + ; + stream << '\n'; + } + stream << ")"; + } + + /// push the a given parameter at the end of the list + void append(torch::Tensor&& param) { + bool requires_grad = param.requires_grad(); + register_parameter( + std::to_string(parameters_.size()), std::move(param), requires_grad); + } + + /// push the a given parameter at the end of the list + void append(const torch::Tensor& param) { + bool requires_grad = param.requires_grad(); + register_parameter( + std::to_string(parameters_.size()), param, requires_grad); + } + + /// push the a given parameter at the end of the list + /// And the key of the pair will be discarded, only the value + /// will be added into the `ParameterList` + void append(const OrderedDict::Item& pair) { + register_parameter( + std::to_string(parameters_.size()), + pair.value(), + pair.value().requires_grad()); + } + + /// extend parameters from a container to the end of the list + template + void extend(const Container& container) { + for (const auto& param : container) { + append(param); + } + } + + /// Returns an iterator to the start of the ParameterList + /// the iterator returned will be type of `OrderedDict::Item` + Iterator begin() { + return parameters_.begin(); + } + + /// Returns a const iterator to the start of the ParameterList + /// the iterator returned will be type of `OrderedDict::Item` + ConstIterator begin() const { + return parameters_.begin(); + } + + /// Returns an iterator to the end of the ParameterList + /// the iterator returned will be type of `OrderedDict::Item` + Iterator end() { + return parameters_.end(); + } + + /// Returns a const iterator to the end of the ParameterList + /// the iterator returned will be type of `OrderedDict::Item` + ConstIterator end() const { + return parameters_.end(); + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterList`. Check contains(key) before + /// for a non-throwing way of access + at::Tensor& at(size_t idx) { + TORCH_CHECK(idx < size(), "Index out of range"); + return parameters_[std::to_string(idx)]; + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterList`. Check contains(key) before + /// for a non-throwing way of access + const at::Tensor& at(size_t idx) const { + TORCH_CHECK(idx < size(), "Index out of range"); + return parameters_[std::to_string(idx)]; + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterList`. Check contains(key) before + /// for a non-throwing way of access + at::Tensor& operator[](size_t idx) { + return at(idx); + } + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `ParameterList`. Check contains(key) before + /// for a non-throwing way of access + const at::Tensor& operator[](size_t idx) const { + return at(idx); + } + + /// Return the size of the ParameterList + size_t size() const noexcept { + return parameters_.size(); + } + /// True if the ParameterList is empty + bool is_empty() const noexcept { + return parameters_.is_empty(); + } + + /// Overload the +=, so that two ParameterList could be incrementally added + template + Container& operator+=(const Container& other) { + extend(other); + return *this; + } + + private: + template + void push_back_var(Head&& head, Tail&&... tail) { + append(std::forward(head)); + // Recursively calls this method, until the parameter pack only thas this + // entry left. Then calls `push_back()` a final time (above). + push_back_var(std::forward(tail)...); + } + + /// The base case, when the list of modules is empty. + void push_back_var() {} +}; +TORCH_MODULE(ParameterList); +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/sequential.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/sequential.h new file mode 100644 index 0000000000000000000000000000000000000000..1f11649575480212e2835179d4768419fa80a9dc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/container/sequential.h @@ -0,0 +1,387 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace torch::nn { + +/// A list of `Module`s that acts as a `Module` itself. +/// +/// A `Sequential` is fundamentally a list of `Module`s, each with a `forward()` +/// method. `Sequential` provides a `forward()` method of its own, which accepts +/// any input and forwards it to the first module it stores. It then "chains" +/// outputs to inputs sequentially for each subsequent module, finally returning +/// the output of the last module. For example: +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::Sequential seq( +/// torch::nn::Linear(3, 4), +/// torch::nn::BatchNorm1d(4), +/// torch::nn::Dropout(0.5) +/// ); +/// +/// auto output = seq->forward(torch::ones(3)); +/// +/// \endrst +/// +/// This can conceptually be thought of as the following loop (using Python as +/// pseudocode): +/// +/// \rst +/// .. code-block:: python +/// +/// def forward(sequential, input): +/// for module in sequential: +/// input = module(input) +/// return input +/// +/// \endrst +/// +/// Why should you use `Sequential` instead of a simple `std::vector`? The value +/// a `Sequential` provides over manually calling a sequence of modules is that +/// it allows treating the whole container *as a single module*, such that +/// performing a transformation on the `Sequential` applies to each of the +/// modules it stores (which are each a registered submodule of the +/// `Sequential`). For example, calling +/// `.to(torch::kCUDA)` on a `Sequential` will move each module in the list to +/// CUDA memory. For example: +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::Sequential seq( +/// torch::nn::Linear(3, 4), +/// torch::nn::BatchNorm1d(4), +/// torch::nn::Dropout(0.5) +/// ); +/// +/// // Convert all modules to CUDA. +/// seq->to(torch::kCUDA); +/// +/// \endrst +/// +/// Finally, `Sequential` provides a lightweight container API, such as allowing +/// iteration over submodules, positional access, adding a new module after +/// construction via `push_back`, as well as joining two `Sequential`s via +/// `extend`. +/// +/// \rst +/// .. attention:: +/// One current limitation of `Sequential` is that all except the first module +/// must accept a single argument. If your modules need to take multiple +/// arguments, you should define them to take and return tuples. +/// \endrst +class SequentialImpl : public Cloneable { + public: + using Iterator = std::vector::iterator; + using ConstIterator = std::vector::const_iterator; + + SequentialImpl() = default; + + /// Constructs the `Sequential` from a variadic list of modules. + template + explicit SequentialImpl(Modules&&... modules) { + modules_.reserve(sizeof...(Modules)); + push_back(std::forward(modules)...); + } + + /// Constructs the `Sequential` from an `OrderedDict` of named `AnyModule`s. + explicit SequentialImpl( + torch::OrderedDict&& ordered_dict) { + modules_.reserve(ordered_dict.size()); + for (auto& item : ordered_dict) { + push_back(item.key(), std::move(item.value())); + } + } + + /// Constructs the `Sequential` from a braced-init-list of named `AnyModule`s. + /// It enables the following use case: + /// `Sequential sequential({{"m1", M(1)}, {"m2", M(2)}})` + explicit SequentialImpl(std::initializer_list named_modules) { + modules_.reserve(named_modules.size()); + for (const auto& named_module : named_modules) { + push_back(named_module.name(), named_module.module()); + } + } + + /// Special cloning function for `Sequential` because it does not use + /// `reset()`. + std::shared_ptr clone( + const std::optional& device = std::nullopt) const override { + auto clone = std::make_shared(); + for (const auto& module : modules_) { + clone->push_back(module.clone(device)); + } + return clone; + } + + /// `reset()` is empty for `Sequential`, since it does not have parameters of + /// its own. + void reset() override {} + + /// Pretty prints the `Sequential` module into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::Sequential"; + } + + /// Feeds `inputs` to the first module and then chains outputs to inputs, + /// returning the last output. + /// + /// Conceptually the following loop in Python: + /// + /// \rst + /// .. code-block:: python + /// + /// def forward(sequential, input): + /// for module in sequential: + /// input = module(input) + /// return input + /// + /// \endrst + /// + /// The return type is taken as the first template parameter. It defaults to + /// `Tensor`. If the last module in the `Sequential` returns another type `T`, + /// you should call `forward(inputs)` instead of just `forward(inputs)`: + /// + /// \rst + /// .. code-block:: cpp + /// + /// torch::Tensor tensor = sequential1->forward(inputs); + /// int integer = sequential2->forward(inputs); + /// float value = sequential3->forward(inputs); + /// + /// \endrst + template + ReturnType forward(InputTypes&&... inputs) { + TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential"); + + auto iterator = modules_.begin(); + auto input = iterator->any_forward(std::forward(inputs)...); + + for (++iterator; iterator != modules_.end(); ++iterator) { + input = iterator->any_forward(std::move(input)); + } + + // Check the return value and give a nice error message if the requested + // return type was incorrect. + if (auto* return_value = input.template try_get()) { + return std::move(*return_value); + } + TORCH_CHECK( + false, + "The type of the return value is ", + c10::demangle(input.type_info().name()), + ", but you asked for type ", + c10::demangle(typeid(ReturnType).name())); + } + + /// Adds a new (boxed) `Module` to the `Sequential` container. + template + void push_back(std::shared_ptr module_ptr) { + push_back(std::to_string(modules_.size()), std::move(module_ptr)); + } + + /// Adds a new named (boxed) `Module` to the `Sequential` container. + template + void push_back(std::string name, std::shared_ptr module_ptr) { + push_back(std::move(name), AnyModule(std::move(module_ptr))); + } + + /// Adds a new `Module` to the `Sequential` container, moving or copying it + /// into a `shared_ptr` internally. This method allows passing value types, + /// and letting the container deal with the boxing. This means you can write + /// `Sequential(Module(3, 4))` instead of + /// `Sequential(std::make_shared(3, 4))`. + template > + void push_back(M&& module) { + push_back(std::to_string(modules_.size()), std::forward(module)); + } + + /// Adds a new named `Module` to the `Sequential` container, moving or copying + /// it into a `shared_ptr` internally. This method allows passing value types, + /// and letting the container deal with the boxing. + template > + void push_back(std::string name, M&& module) { + using Type = typename std::remove_reference_t; + push_back(std::move(name), std::make_shared(std::forward(module))); + } + + /// Unwraps the contained module of a `ModuleHolder` and adds it to the + /// `Sequential`. + template + void push_back(const ModuleHolder& module_holder) { + push_back(std::to_string(modules_.size()), module_holder); + } + + /// Unwraps the contained named module of a `ModuleHolder` and adds it to the + /// `Sequential`. + template + void push_back(std::string name, const ModuleHolder& module_holder) { + push_back(std::move(name), module_holder.ptr()); + } + + /// Iterates over the container and calls `push_back()` on each value. + template + void extend(const Container& container) { + for (const auto& module : container) { + push_back(module); + } + } + + /// Adds a type-erased `AnyModule` to the `Sequential`. + void push_back(AnyModule any_module) { + push_back(std::to_string(modules_.size()), std::move(any_module)); + } + + void push_back(std::string name, AnyModule any_module) { + modules_.push_back(std::move(any_module)); + const auto index = modules_.size() - 1; + register_module(std::move(name), modules_[index].ptr()); + } + + /// Returns an iterator to the start of the `Sequential`. + Iterator begin() { + return modules_.begin(); + } + + /// Returns a const iterator to the start of the `Sequential`. + ConstIterator begin() const { + return modules_.begin(); + } + + /// Returns an iterator to the end of the `Sequential`. + Iterator end() { + return modules_.end(); + } + + /// Returns a const iterator to the end of the `Sequential`. + ConstIterator end() const { + return modules_.end(); + } + + /// Attempts to return the module at the given index as the requested type. + /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + T& at(size_t index) { + static_assert( + torch::detail::is_module::value, + "Can only call Sequential::at with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + return modules_[index].get(); + } + + /// Attempts to return the module at the given index as the requested type. + /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + const T& at(size_t index) const { + static_assert( + torch::detail::is_module::value, + "Can only call Sequential::at with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + return modules_[index].get(); + } + + /// Attempts to return a `std::shared_ptr` whose dynamic type is that of the + /// underlying module at the given index. Throws an exception if the index is + /// out of bounds. + std::shared_ptr ptr(size_t index) const { + TORCH_CHECK(index < size(), "Index out of range"); + return modules_[index].ptr(); + } + + /// Attempts to return a `std::shared_ptr` whose type is the one provided. + /// Throws an exception if the index is out of bounds or the types do not + /// match. + template + std::shared_ptr ptr(size_t index) const { + static_assert( + torch::detail::is_module::value, + "Can only call Sequential::ptr with an nn::Module type"); + TORCH_CHECK(index < size(), "Index out of range"); + return modules_[index].ptr(); + } + + /// Like `ptr(index)`. + std::shared_ptr operator[](size_t index) const { + // This is the only method we can call without a type. + return ptr(index); + } + + /// The current size of the `Sequential` container. + size_t size() const noexcept { + return modules_.size(); + } + + /// True if there are no modules in the `Sequential`. + bool is_empty() const noexcept { + return size() == 0; + } + + private: + /// Takes a First *and* Second parameter, to avoid ambiguity when a parameter + /// pack has only one type, in which case the template would be preferred, + /// even if the other `push_back` functions are better fits (e.g. `unique_ptr` + /// -> `shared_ptr` overload). + /// NOTE: We explicitly avoid matching this template with + /// `push_back(std::string("name"), module)` or `push_back("name", module)`, + /// since they should be handled by their respective `push_back` functions. + template < + typename First, + typename Second, + typename... Rest, + typename = std::enable_if_t< + !std::is_same_v && + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) + !std::is_same_v, std::decay_t>>> + void push_back(First&& first, Second&& second, Rest&&... rest) { + push_back(std::forward(first)); + // Recursively calls this method, until the parameter pack only thas this + // entry left. Then calls `push_back()` a final time (above). + push_back(std::forward(second), std::forward(rest)...); + } + + /// The base case, when the list of modules is empty. + void push_back() {} + + // Box the AnyModules to give Sequential reference semantics, like the rest of + // the API. Note that this is not required otherwise, this could just be a + // `vector`. + std::vector modules_; +}; + +/// A `ModuleHolder` subclass for `SequentialImpl`. +/// See the documentation for `SequentialImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +class Sequential : public torch::nn::ModuleHolder { + public: + using torch::nn::ModuleHolder::ModuleHolder; + + Sequential() = default; + + /// Constructs the `Sequential` from a braced-init-list of named `AnyModule`s. + /// It enables the following use case: + /// `Sequential sequential({{"m1", M(1)}, {"m2", M(2)}})` + Sequential(std::initializer_list named_modules) + : ModuleHolder(std::make_shared(named_modules)) {} +}; +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/conv.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/conv.h new file mode 100644 index 0000000000000000000000000000000000000000..8c5f1f3e39182fb5a5b03356f863e1082d366fe2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/conv.h @@ -0,0 +1,448 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace torch::nn { + +/// Base class for all (dimension-specialized) convolution modules. +template +class ConvNdImpl : public torch::nn::Cloneable { + public: + explicit ConvNdImpl(detail::ConvNdOptions options_) + : options(std::move(options_)) { + ConvNdImpl::reset(); + } + + void reset() override { + TORCH_CHECK( + options.in_channels() > 0 && options.groups() > 0 && + options.out_channels() > 0, + "in_channels, groups and out_channels must be a positive integer."); + TORCH_CHECK( + options.in_channels() % options.groups() == 0, + "in_channels must be divisible by groups"); + TORCH_CHECK( + options.out_channels() % options.groups() == 0, + "out_channels must be divisible by groups"); + + std::visit( + c10::overloaded( + [&](enumtype::kValid) { + _reversed_padding_repeated_twice.resize(2 * D); + std::fill_n(_reversed_padding_repeated_twice.begin(), 2 * D, 0); + }, + [&](enumtype::kSame) { + for (const auto i : c10::irange(D)) { + const auto stride = (*options.stride())[i]; + TORCH_CHECK( + stride == 1, + "padding='same' is not supported for strided convolutions"); + } + + _reversed_padding_repeated_twice.resize(2 * D); + for (const auto i : c10::irange(D)) { + const auto dilation = (*options.dilation())[i]; + const auto kernel_size = (*options.kernel_size())[i]; + const auto total_padding = dilation * (kernel_size - 1); + auto left_pad = total_padding / 2; + auto right_pad = total_padding - left_pad; + _reversed_padding_repeated_twice[2 * i] = left_pad; + _reversed_padding_repeated_twice[2 * i + 1] = right_pad; + } + }, + [&](const ExpandingArray& pad) { + _reversed_padding_repeated_twice = + torch::nn::modules::utils::_reverse_repeat_vector(pad, 2); + }), + options.padding()); + + if (options.transposed()) { + std::vector weight_sizes = { + options.in_channels(), options.out_channels() / options.groups()}; + weight_sizes.insert( + weight_sizes.end(), + (*options.kernel_size()).begin(), + (*options.kernel_size()).end()); + weight = this->register_parameter("weight", torch::empty(weight_sizes)); + } else { + std::vector weight_sizes = { + options.out_channels(), options.in_channels() / options.groups()}; + weight_sizes.insert( + weight_sizes.end(), + (*options.kernel_size()).begin(), + (*options.kernel_size()).end()); + weight = this->register_parameter("weight", torch::empty(weight_sizes)); + } + + if (options.bias()) { + bias = this->register_parameter( + "bias", torch::empty({options.out_channels()})); + } else { + this->register_parameter("bias", Tensor(), /*requires_grad=*/false); + } + + reset_parameters(); + } + + void reset_parameters() { + init::kaiming_uniform_( + weight, + /*a=*/std::sqrt(5)); // NOLINT(cppcoreguidelines-avoid-magic-numbers) + + if (bias.defined()) { + auto [fan_in, fan_out] = init::_calculate_fan_in_and_fan_out(weight); + auto bound = 1 / std::sqrt(fan_in); + init::uniform_(bias, -bound, bound); + } + } + + /// Pretty prints the `Conv{1,2,3}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::Conv" << D << "d" + << "(" << options.in_channels() << ", " << options.out_channels() + << ", kernel_size=" << options.kernel_size() + << ", stride=" << options.stride(); + std::visit( + c10::overloaded( + [&](enumtype::kValid) { stream << ", padding='valid'"; }, + [&](enumtype::kSame) { stream << ", padding='same'"; }, + [&](const ExpandingArray& pad) { + if (*pad != *ExpandingArray(0)) { + stream << ", padding=" << pad; + } + }), + options.padding()); + if (*options.dilation() != *ExpandingArray(1)) { + stream << ", dilation=" << options.dilation(); + } + if (*options.output_padding() != *ExpandingArray(0)) { + stream << ", output_padding=" << options.output_padding(); + } + if (options.groups() != 1) { + stream << ", groups=" << options.groups(); + } + if (!options.bias()) { + stream << ", bias=" << std::boolalpha << false; + } + if (!std::get_if(&options.padding_mode())) { + stream << ", padding_mode=" + << enumtype::get_enum_name(options.padding_mode()); + } + stream << ")"; + } + + /// The options with which this `Module` was constructed. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + detail::ConvNdOptions options; + + /// The learned kernel (or "weight"). + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + Tensor weight; + + /// The learned bias. Only defined if the `bias` option was true. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + Tensor bias; + + protected: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::vector _reversed_padding_repeated_twice; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies convolution over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Conv1d to learn about +/// the exact behavior of this module. +/// +/// See the documentation for `torch::nn::Conv1dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Conv1d model(Conv1dOptions(3, 2, 3).stride(1).bias(false)); +/// ``` +class TORCH_API Conv1dImpl : public ConvNdImpl<1, Conv1dImpl> { + public: + Conv1dImpl( + int64_t input_channels, + int64_t output_channels, + ExpandingArray<1> kernel_size) + : Conv1dImpl( + Conv1dOptions(input_channels, output_channels, kernel_size)) {} + explicit Conv1dImpl(Conv1dOptions options_); + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `Conv1dImpl`. +/// See the documentation for `Conv1dImpl` class to learn what methods it +/// provides, and examples of how to use `Conv1d` with +/// `torch::nn::Conv1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Conv1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies convolution over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Conv2d to learn about +/// the exact behavior of this module. +/// +/// See the documentation for `torch::nn::Conv2dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false)); +/// ``` +class TORCH_API Conv2dImpl : public ConvNdImpl<2, Conv2dImpl> { + public: + Conv2dImpl( + int64_t input_channels, + int64_t output_channels, + ExpandingArray<2> kernel_size) + : Conv2dImpl( + Conv2dOptions(input_channels, output_channels, kernel_size)) {} + explicit Conv2dImpl(Conv2dOptions options_); + Tensor forward(const Tensor& input); + + protected: + Tensor _conv_forward(const Tensor& input, const Tensor& weight); +}; + +/// A `ModuleHolder` subclass for `Conv2dImpl`. +/// See the documentation for `Conv2dImpl` class to learn what methods it +/// provides, and examples of how to use `Conv2d` with +/// `torch::nn::Conv2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Conv2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies convolution over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Conv3d to learn about +/// the exact behavior of this module. +/// +/// See the documentation for `torch::nn::Conv3dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false)); +/// ``` +class TORCH_API Conv3dImpl : public ConvNdImpl<3, Conv3dImpl> { + public: + Conv3dImpl( + int64_t input_channels, + int64_t output_channels, + ExpandingArray<3> kernel_size) + : Conv3dImpl( + Conv3dOptions(input_channels, output_channels, kernel_size)) {} + explicit Conv3dImpl(Conv3dOptions options_); + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `Conv3dImpl`. +/// See the documentation for `Conv3dImpl` class to learn what methods it +/// provides, and examples of how to use `Conv3d` with +/// `torch::nn::Conv3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Conv3d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Base class for all (dimension-specialized) convolution transpose modules. +template +class ConvTransposeNdImpl : public ConvNdImpl { + public: + using torch::nn::ConvNdImpl::ConvNdImpl; + explicit ConvTransposeNdImpl(detail::ConvNdOptions options_) + : ConvNdImpl(options_) { + TORCH_INTERNAL_ASSERT( + std::holds_alternative>(this->options.padding()), + "ConvTranspose padding cannot be a string"); + } + + /// Pretty prints the `ConvTranspose{1,2,3}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::ConvTranspose" << D << "d" + << "(" << this->options.in_channels() << ", " + << this->options.out_channels() + << ", kernel_size=" << this->options.kernel_size() + << ", stride=" << this->options.stride(); + const auto& pad = padding(); + if (*pad != *ExpandingArray(0)) { + stream << ", padding=" << pad; + } + if (*this->options.dilation() != *ExpandingArray(1)) { + stream << ", dilation=" << this->options.dilation(); + } + if (*this->options.output_padding() != *ExpandingArray(0)) { + stream << ", output_padding=" << this->options.output_padding(); + } + if (this->options.groups() != 1) { + stream << ", groups=" << this->options.groups(); + } + if (!this->options.bias()) { + stream << ", bias=" << std::boolalpha << false; + } + if (!std::get_if(&this->options.padding_mode())) { + stream << ", padding_mode=" + << enumtype::get_enum_name(this->options.padding_mode()); + } + stream << ")"; + } + + protected: + const ExpandingArray& padding() const { + return std::get>(this->options.padding()); + } + + std::vector _output_padding( + const Tensor& input, + const std::optional& output_size, + const ExpandingArray& stride, + const ExpandingArray& padding, + const ExpandingArray& kernel_size); +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose1d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the ConvTranspose1d function. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ConvTranspose1d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ConvTranspose1dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ConvTranspose1d model(ConvTranspose1dOptions(3, 2, +/// 3).stride(1).bias(false)); +/// ``` +class TORCH_API ConvTranspose1dImpl + : public ConvTransposeNdImpl<1, ConvTranspose1dImpl> { + public: + ConvTranspose1dImpl( + int64_t input_channels, + int64_t output_channels, + ExpandingArray<1> kernel_size) + : ConvTranspose1dImpl(ConvTranspose1dOptions( + input_channels, + output_channels, + kernel_size)) {} + explicit ConvTranspose1dImpl(ConvTranspose1dOptions options_); + Tensor forward( + const Tensor& input, + const std::optional& output_size = std::nullopt); + + protected: + FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional())}) +}; + +/// A `ModuleHolder` subclass for `ConvTranspose1dImpl`. +/// See the documentation for `ConvTranspose1dImpl` class to learn what methods +/// it provides, and examples of how to use `ConvTranspose1d` with +/// `torch::nn::ConvTranspose1dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(ConvTranspose1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the ConvTranspose2d function. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ConvTranspose2d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ConvTranspose2dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ConvTranspose2d model(ConvTranspose2dOptions(3, 2, +/// 3).stride(1).bias(false)); +/// ``` +class TORCH_API ConvTranspose2dImpl + : public ConvTransposeNdImpl<2, ConvTranspose2dImpl> { + public: + ConvTranspose2dImpl( + int64_t input_channels, + int64_t output_channels, + ExpandingArray<2> kernel_size) + : ConvTranspose2dImpl(ConvTranspose2dOptions( + input_channels, + output_channels, + kernel_size)) {} + explicit ConvTranspose2dImpl(ConvTranspose2dOptions options_); + Tensor forward( + const Tensor& input, + const std::optional& output_size = std::nullopt); + + protected: + FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional())}) +}; + +/// A `ModuleHolder` subclass for `ConvTranspose2dImpl`. +/// See the documentation for `ConvTranspose2dImpl` class to learn what methods +/// it provides, and examples of how to use `ConvTranspose2d` with +/// `torch::nn::ConvTranspose2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(ConvTranspose2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose3d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the ConvTranspose3d function. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ConvTranspose3d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ConvTranspose3dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ConvTranspose3d model(ConvTranspose3dOptions(2, 2, +/// 2).stride(1).bias(false)); +/// ``` +class TORCH_API ConvTranspose3dImpl + : public ConvTransposeNdImpl<3, ConvTranspose3dImpl> { + public: + ConvTranspose3dImpl( + int64_t input_channels, + int64_t output_channels, + ExpandingArray<3> kernel_size) + : ConvTranspose3dImpl(ConvTranspose3dOptions( + input_channels, + output_channels, + kernel_size)) {} + explicit ConvTranspose3dImpl(ConvTranspose3dOptions options_); + Tensor forward( + const Tensor& input, + const std::optional& output_size = std::nullopt); + + protected: + FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional())}) +}; + +/// A `ModuleHolder` subclass for `ConvTranspose3dImpl`. +/// See the documentation for `ConvTranspose3dImpl` class to learn what methods +/// it provides, and examples of how to use `ConvTranspose3d` with +/// `torch::nn::ConvTranspose3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(ConvTranspose3d); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/distance.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/distance.h new file mode 100644 index 0000000000000000000000000000000000000000..7166ba15d182154a879cde7522465ec8909a8235 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/distance.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace torch::nn { + +/// Returns the cosine similarity between :math:`x_1` and :math:`x_2`, computed +/// along `dim`. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.CosineSimilarity to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::CosineSimilarityOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// CosineSimilarity model(CosineSimilarityOptions().dim(0).eps(0.5)); +/// ``` +class TORCH_API CosineSimilarityImpl : public Cloneable { + public: + explicit CosineSimilarityImpl(const CosineSimilarityOptions& options_ = {}); + + void reset() override; + + /// Pretty prints the `CosineSimilarity` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input1, const Tensor& input2); + + /// The options with which this `Module` was constructed. + CosineSimilarityOptions options; +}; + +/// A `ModuleHolder` subclass for `CosineSimilarityImpl`. +/// See the documentation for `CosineSimilarityImpl` class to learn what methods +/// it provides, and examples of how to use `CosineSimilarity` with +/// `torch::nn::CosineSimilarityOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(CosineSimilarity); + +// ============================================================================ + +/// Returns the batchwise pairwise distance between vectors :math:`v_1`, +/// :math:`v_2` using the p-norm. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.PairwiseDistance to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::PairwiseDistanceOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// PairwiseDistance +/// model(PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true)); +/// ``` +class TORCH_API PairwiseDistanceImpl : public Cloneable { + public: + explicit PairwiseDistanceImpl(const PairwiseDistanceOptions& options_ = {}); + + void reset() override; + + /// Pretty prints the `PairwiseDistance` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input1, const Tensor& input2); + + /// The options with which this `Module` was constructed. + PairwiseDistanceOptions options; +}; + +/// A `ModuleHolder` subclass for `PairwiseDistanceImpl`. +/// See the documentation for `PairwiseDistanceImpl` class to learn what methods +/// it provides, and examples of how to use `PairwiseDistance` with +/// `torch::nn::PairwiseDistanceOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(PairwiseDistance); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/dropout.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/dropout.h new file mode 100644 index 0000000000000000000000000000000000000000..e98b92499fb28ddf662a5dda30d071cb75a3a1e6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/dropout.h @@ -0,0 +1,184 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace torch::nn { + +namespace detail { + +template +class _DropoutNd : public torch::nn::Cloneable { + public: + _DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)) {} + + explicit _DropoutNd(const DropoutOptions& options_ = {}) : options(options_) { + _DropoutNd::reset(); + } + + void reset() override { + TORCH_CHECK( + options.p() >= 0. && options.p() <= 1., + "dropout probability has to be between 0 and 1, but got ", + options.p()); + } + + /// The options with which this `Module` was constructed. + DropoutOptions options; +}; + +} // namespace detail + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies dropout over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::DropoutOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Dropout model(DropoutOptions().p(0.42).inplace(true)); +/// ``` +class TORCH_API DropoutImpl : public detail::_DropoutNd { + public: + using detail::_DropoutNd::_DropoutNd; + + Tensor forward(Tensor input); + + /// Pretty prints the `Dropout` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `DropoutImpl`. +/// See the documentation for `DropoutImpl` class to learn what methods it +/// provides, and examples of how to use `Dropout` with +/// `torch::nn::DropoutOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Dropout); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies dropout over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout2d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::Dropout2dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Dropout2d model(Dropout2dOptions().p(0.42).inplace(true)); +/// ``` +class TORCH_API Dropout2dImpl : public detail::_DropoutNd { + public: + using detail::_DropoutNd::_DropoutNd; + + Tensor forward(Tensor input); + + /// Pretty prints the `Dropout2d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `Dropout2dImpl`. +/// See the documentation for `Dropout2dImpl` class to learn what methods it +/// provides, and examples of how to use `Dropout2d` with +/// `torch::nn::Dropout2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Dropout2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies dropout over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout3d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::Dropout3dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Dropout3d model(Dropout3dOptions().p(0.42).inplace(true)); +/// ``` +class TORCH_API Dropout3dImpl : public detail::_DropoutNd { + public: + using detail::_DropoutNd::_DropoutNd; + + Tensor forward(Tensor input); + + /// Pretty prints the `Dropout3d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `Dropout3dImpl`. +/// See the documentation for `Dropout3dImpl` class to learn what methods it +/// provides, and examples of how to use `Dropout3d` with +/// `torch::nn::Dropout3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Dropout3d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies Alpha Dropout over the input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AlphaDropout to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AlphaDropoutOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AlphaDropout model(AlphaDropoutOptions(0.2).inplace(true)); +/// ``` +class TORCH_API AlphaDropoutImpl : public detail::_DropoutNd { + public: + using detail::_DropoutNd::_DropoutNd; + + Tensor forward(const Tensor& input); + + /// Pretty prints the `AlphaDropout` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `AlphaDropoutImpl`. +/// See the documentation for `AlphaDropoutImpl` class to learn what methods it +/// provides, and examples of how to use `AlphaDropout` with +/// `torch::nn::AlphaDropoutOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(AlphaDropout); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureAlphaDropout +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// See the documentation for `torch::nn::FeatureAlphaDropoutOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// FeatureAlphaDropout model(FeatureAlphaDropoutOptions(0.2).inplace(true)); +/// ``` +class TORCH_API FeatureAlphaDropoutImpl + : public detail::_DropoutNd { + public: + using detail::_DropoutNd::_DropoutNd; + + Tensor forward(const Tensor& input); + + /// Pretty prints the `FeatureAlphaDropout` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +/// A `ModuleHolder` subclass for `FeatureAlphaDropoutImpl`. +/// See the documentation for `FeatureAlphaDropoutImpl` class to learn what +/// methods it provides, and examples of how to use `FeatureAlphaDropout` with +/// `torch::nn::FeatureAlphaDropoutOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(FeatureAlphaDropout); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/embedding.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/embedding.h new file mode 100644 index 0000000000000000000000000000000000000000..f8af433bcc4c10fb530574f56248d009f86c1bc9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/embedding.h @@ -0,0 +1,165 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace torch::nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Embedding +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Performs a lookup in a fixed size embedding table. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Embedding to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::EmbeddingOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Embedding model(EmbeddingOptions(10, +/// 2).padding_idx(3).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true)); +/// ``` +class TORCH_API EmbeddingImpl : public torch::nn::Cloneable { + public: + EmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim) + : EmbeddingImpl(EmbeddingOptions(num_embeddings, embedding_dim)) {} + explicit EmbeddingImpl(EmbeddingOptions options_); + + void reset() override; + + void reset_parameters(); + + /// Pretty prints the `Embedding` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// Performs a lookup on the embedding table stored in `weight` using the + /// `indices` supplied and returns the result. + Tensor forward(const Tensor& indices); + + /// The `Options` used to configure this `Embedding` module. + /// Changes to `EmbeddingOptions` *after construction* have no effect. + EmbeddingOptions options; + + /// The embedding table. + Tensor weight; +}; + +/// A `ModuleHolder` subclass for `EmbeddingImpl`. +/// See the documentation for `EmbeddingImpl` class to learn what methods it +/// provides, and examples of how to use `Embedding` with +/// `torch::nn::EmbeddingOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +class Embedding : public torch::nn::ModuleHolder { + public: + using torch::nn::ModuleHolder::ModuleHolder; + + /// See the documentation for `torch::nn::EmbeddingFromPretrainedOptions` + /// class to learn what optional arguments are supported for this function. + static Embedding from_pretrained( + const torch::Tensor& embeddings, + const EmbeddingFromPretrainedOptions& options = {}) { + TORCH_CHECK( + embeddings.dim() == 2, + "Embeddings parameter is expected to be 2-dimensional"); + + auto rows = embeddings.size(0); + auto cols = embeddings.size(1); + + Embedding embedding(EmbeddingOptions(rows, cols) + ._weight(embeddings) + .padding_idx(options.padding_idx()) + .max_norm(options.max_norm()) + .norm_type(options.norm_type()) + .scale_grad_by_freq(options.scale_grad_by_freq()) + .sparse(options.sparse())); + embedding->weight.set_requires_grad(!options.freeze()); + return embedding; + } +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ EmbeddingBag +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Computes sums or means of 'bags' of embeddings, without instantiating the +/// intermediate embeddings. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.EmbeddingBag to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::EmbeddingBagOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// EmbeddingBag model(EmbeddingBagOptions(10, +/// 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true).mode(torch::kSum).padding_idx(1)); +/// ``` +class TORCH_API EmbeddingBagImpl + : public torch::nn::Cloneable { + public: + EmbeddingBagImpl(int64_t num_embeddings, int64_t embedding_dim) + : EmbeddingBagImpl(EmbeddingBagOptions(num_embeddings, embedding_dim)) {} + explicit EmbeddingBagImpl(EmbeddingBagOptions options_); + + void reset() override; + + void reset_parameters(); + + /// Pretty prints the `EmbeddingBag` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The `Options` used to configure this `EmbeddingBag` module. + EmbeddingBagOptions options; + /// The embedding table. + Tensor weight; + + Tensor forward( + const Tensor& input, + const Tensor& offsets = {}, + const Tensor& per_sample_weights = {}); + + protected: + FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())}) +}; + +/// A `ModuleHolder` subclass for `EmbeddingBagImpl`. +/// See the documentation for `EmbeddingBagImpl` class to learn what methods it +/// provides, and examples of how to use `EmbeddingBag` with +/// `torch::nn::EmbeddingBagOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +class EmbeddingBag : public torch::nn::ModuleHolder { + public: + using torch::nn::ModuleHolder::ModuleHolder; + + /// See the documentation for `torch::nn::EmbeddingBagFromPretrainedOptions` + /// class to learn what optional arguments are supported for this function. + static EmbeddingBag from_pretrained( + const torch::Tensor& embeddings, + const EmbeddingBagFromPretrainedOptions& options = {}) { + TORCH_CHECK( + embeddings.dim() == 2, + "Embeddings parameter is expected to be 2-dimensional"); + + auto rows = embeddings.size(0); + auto cols = embeddings.size(1); + + EmbeddingBag embeddingbag( + EmbeddingBagOptions(rows, cols) + ._weight(embeddings) + .max_norm(options.max_norm()) + .norm_type(options.norm_type()) + .scale_grad_by_freq(options.scale_grad_by_freq()) + .mode(options.mode()) + .sparse(options.sparse()) + .padding_idx(options.padding_idx())); + embeddingbag->weight.set_requires_grad(!options.freeze()); + return embeddingbag; + } +}; +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/fold.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/fold.h new file mode 100644 index 0000000000000000000000000000000000000000..4ad49f191fbbac76b0695603615b4c5a2cf41ca1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/fold.h @@ -0,0 +1,85 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace torch::nn { + +/// Applies fold over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Fold to learn about +/// the exact behavior of this module. +/// +/// See the documentation for `torch::nn::FoldOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, +/// 1}).stride(2)); +/// ``` +class TORCH_API FoldImpl : public torch::nn::Cloneable { + public: + FoldImpl(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size) + : FoldImpl(FoldOptions(output_size, kernel_size)) {} + explicit FoldImpl(const FoldOptions& options_); + + void reset() override; + + /// Pretty prints the `Fold` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + /// The options with which this `Module` was constructed. + FoldOptions options; +}; + +/// A `ModuleHolder` subclass for `FoldImpl`. +/// See the documentation for `FoldImpl` class to learn what methods it +/// provides, and examples of how to use `Fold` with `torch::nn::FoldOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(Fold); + +// ============================================================================ + +/// Applies unfold over a 4-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Unfold to learn about +/// the exact behavior of this module. +/// +/// See the documentation for `torch::nn::UnfoldOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Unfold model(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2)); +/// ``` +class TORCH_API UnfoldImpl : public Cloneable { + public: + UnfoldImpl(ExpandingArray<2> kernel_size) + : UnfoldImpl(UnfoldOptions(kernel_size)) {} + explicit UnfoldImpl(const UnfoldOptions& options_); + + void reset() override; + + /// Pretty prints the `Unfold` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + /// The options with which this `Module` was constructed. + UnfoldOptions options; +}; + +/// A `ModuleHolder` subclass for `UnfoldImpl`. +/// See the documentation for `UnfoldImpl` class to learn what methods it +/// provides, and examples of how to use `Unfold` with +/// `torch::nn::UnfoldOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Unfold); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/instancenorm.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/instancenorm.h new file mode 100644 index 0000000000000000000000000000000000000000..228f181715fc775f2d4d44dd2f69c747bee7fd78 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/instancenorm.h @@ -0,0 +1,153 @@ +#pragma once + +#include +#include +#include + +namespace torch::nn { + +/// Base class for all (dimension-specialized) instance norm modules +template +// NOLINTNEXTLINE(bugprone-crtp-constructor-accessibility) +class InstanceNormImpl + : public torch::nn::NormImplBase { + private: + inline Tensor apply_instance_norm(const Tensor& input) { + return torch::nn::functional::detail::instance_norm( + input, + this->running_mean, + this->running_var, + this->weight, + this->bias, + this->is_training() || !this->options.track_running_stats(), + this->options.momentum(), + this->options.eps()); + } + + inline Tensor handle_no_batch_input(const Tensor& input) { + return this->apply_instance_norm(input.unsqueeze(0)).squeeze(0); + } + + public: + using torch::nn::NormImplBase::NormImplBase; + + Tensor forward(const Tensor& input) { + this->_check_input_dim(input); + + // For InstanceNorm1D, 2D is unbatched and 3D is batched + // For InstanceNorm2D, 3D is unbatched and 4D is batched + // For InstanceNorm3D, 4D is unbatched and 5D is batched + // check if input does not have a batch-dim + if (input.dim() == D + 1) { + return this->handle_no_batch_input(input); + } + + return this->apply_instance_norm(input); + } + + /// Pretty prints the `InstanceNorm{1,2,3}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << std::boolalpha << "torch::nn::InstanceNorm" << D << "d(" + << this->options.num_features() << ", " + << "eps=" << this->options.eps() << ", " + << "momentum=" << this->options.momentum() << ", " + << "affine=" << this->options.affine() << ", " + << "track_running_stats=" << this->options.track_running_stats() + << ")"; + } +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm1d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the InstanceNorm1d function. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm1d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::InstanceNorm1dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// InstanceNorm1d +/// model(InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +class TORCH_API InstanceNorm1dImpl + : public InstanceNormImpl<1, InstanceNorm1dImpl> { + protected: + void _check_input_dim(const Tensor& input) override; + + public: + using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl; +}; + +/// A `ModuleHolder` subclass for `InstanceNorm1dImpl`. +/// See the documentation for `InstanceNorm1dImpl` class to learn what methods +/// it provides, and examples of how to use `InstanceNorm1d` with +/// `torch::nn::InstanceNorm1dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(InstanceNorm1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the InstanceNorm2d function. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm2d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::InstanceNorm2dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// InstanceNorm2d +/// model(InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +class TORCH_API InstanceNorm2dImpl + : public InstanceNormImpl<2, InstanceNorm2dImpl> { + protected: + void _check_input_dim(const Tensor& input) override; + + public: + using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl; +}; + +/// A `ModuleHolder` subclass for `InstanceNorm2dImpl`. +/// See the documentation for `InstanceNorm2dImpl` class to learn what methods +/// it provides, and examples of how to use `InstanceNorm2d` with +/// `torch::nn::InstanceNorm2dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(InstanceNorm2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm3d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the InstanceNorm3d function. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm3d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::InstanceNorm3dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// InstanceNorm3d +/// model(InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +class TORCH_API InstanceNorm3dImpl + : public InstanceNormImpl<3, InstanceNorm3dImpl> { + protected: + void _check_input_dim(const Tensor& input) override; + + public: + using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl; +}; + +/// A `ModuleHolder` subclass for `InstanceNorm3dImpl`. +/// See the documentation for `InstanceNorm3dImpl` class to learn what methods +/// it provides, and examples of how to use `InstanceNorm3d` with +/// `torch::nn::InstanceNorm3dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(InstanceNorm3d); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/linear.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/linear.h new file mode 100644 index 0000000000000000000000000000000000000000..cb54396837840d1fa650500ab66a160a96bc73fb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/linear.h @@ -0,0 +1,214 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Identity ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// A placeholder identity operator that is argument-insensitive. +/// See https://pytorch.org/docs/main/generated/torch.nn.Identity.html to +/// learn about the exact behavior of this module. +class TORCH_API IdentityImpl : public Cloneable { + public: + void reset() override; + + /// Pretty prints the `Identity` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `IdentityImpl`. +/// See the documentation for `IdentityImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(Identity); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Linear ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies a linear transformation with optional bias. +/// See https://pytorch.org/docs/main/generated/torch.nn.Linear.html to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::LinearOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Linear model(LinearOptions(5, 2).bias(false)); +/// ``` +class TORCH_API LinearImpl : public Cloneable { + public: + LinearImpl(int64_t in_features, int64_t out_features) + : LinearImpl(LinearOptions(in_features, out_features)) {} + explicit LinearImpl(const LinearOptions& options_); + + void reset() override; + + void reset_parameters(); + + /// Pretty prints the `Linear` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// Transforms the `input` tensor by multiplying with the `weight` and + /// optionally adding the `bias`, if `with_bias` is true in the options. + Tensor forward(const Tensor& input); + + /// The options used to configure this module. + LinearOptions options; + + /// The learned weight. + Tensor weight; + + /// The learned bias. If `bias` is false in the `options`, this tensor is + /// undefined. + Tensor bias; +}; + +/// A `ModuleHolder` subclass for `LinearImpl`. +/// See the documentation for `LinearImpl` class to learn what methods it +/// provides, and examples of how to use `Linear` with +/// `torch::nn::LinearOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Linear); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Flatten ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// A placeholder for Flatten operator +/// See https://pytorch.org/docs/main/generated/torch.nn.Flatten.html to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::FlattenOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Flatten model(FlattenOptions().start_dim(2).end_dim(4)); +/// ``` +class TORCH_API FlattenImpl : public Cloneable { + public: + explicit FlattenImpl(const FlattenOptions& options_ = {}); + + void reset() override; + + /// Pretty prints the `Flatten` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// Applies a flatten transform on the `input`. + Tensor forward(const Tensor& input); + + /// The options used to configure this module. + FlattenOptions options; +}; + +/// A `ModuleHolder` subclass for `FlattenImpl`. +/// See the documentation for `FlattenImpl` class to learn what methods it +/// provides, and examples of how to use `Flatten` with +/// `torch::nn::FlattenOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Flatten); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Unflatten +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// A placeholder for unflatten operator +/// See https://pytorch.org/docs/main/generated/torch.nn.Unflatten.html to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::UnflattenOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Unflatten model(UnflattenOptions(0, {2, 2})); +/// Unflatten model(UnflattenOptions("B", {{"B1", 2}, {"B2", 2}})); +/// ``` +class TORCH_API UnflattenImpl : public Cloneable { + public: + UnflattenImpl(int64_t dim, std::vector sizes) + : UnflattenImpl(UnflattenOptions(dim, std::move(sizes))) {} + UnflattenImpl(std::string dimname, UnflattenOptions::namedshape_t namedshape) + : UnflattenImpl( + UnflattenOptions(std::move(dimname), std::move(namedshape))) {} + explicit UnflattenImpl(UnflattenOptions options_); + + void reset() override; + + /// Pretty prints the `Unflatten` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// Applies an unflatten transform on the `input`. + Tensor forward(const Tensor& input); + + /// The options used to configure this module. + UnflattenOptions options; +}; + +/// A `ModuleHolder` subclass for `UnflattenImpl`. +/// See the documentation for `UnflattenImpl` class to learn what methods it +/// provides, and examples of how to use `Unflatten` with +/// `torch::nn::UnflattenOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Unflatten); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Bilinear ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies a billinear transformation with optional bias. +/// See https://pytorch.org/docs/main/generated/torch.nn.Bilinear.html to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::BilinearOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Bilinear model(BilinearOptions(3, 2, 4).bias(false)); +/// ``` +class TORCH_API BilinearImpl : public Cloneable { + public: + BilinearImpl(int64_t in1_features, int64_t in2_features, int64_t out_features) + : BilinearImpl( + BilinearOptions(in1_features, in2_features, out_features)) {} + explicit BilinearImpl(const BilinearOptions& options_); + + void reset() override; + + void reset_parameters(); + + /// Pretty prints the `Bilinear` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// Applies a bilinear transform on the `input1` and `input2` tensor by + /// multiplying with the `weight` and optionally adding the `bias`, if + /// `with_bias` is true in the options. + Tensor forward(const Tensor& input1, const Tensor& input2); + + /// The options used to configure this module. + BilinearOptions options; + + /// The learned weight. + Tensor weight; + + /// The learned bias. If `with_bias` is false in the `options`, this tensor is + /// undefined. + Tensor bias; +}; + +/// A `ModuleHolder` subclass for `BilinearImpl`. +/// See the documentation for `BilinearImpl` class to learn what methods it +/// provides, and examples of how to use `Bilinear` with +/// `torch::nn::BilinearOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Bilinear); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/loss.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/loss.h new file mode 100644 index 0000000000000000000000000000000000000000..52be4f612b59fb39429bce601f39024f4c399561 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/loss.h @@ -0,0 +1,803 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace torch::nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ L1Loss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the mean absolute error (MAE) between each +/// element in the input : math :`x` and target : `y`. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.L1Loss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::L1LossOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// L1Loss model(L1LossOptions(torch::kNone)); +/// ``` +struct TORCH_API L1LossImpl : Cloneable { + explicit L1LossImpl(L1LossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `L1Loss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + L1LossOptions options; +}; + +/// A `ModuleHolder` subclass for `L1LossImpl`. +/// See the documentation for `L1LossImpl` class to learn what methods it +/// provides, and examples of how to use `L1Loss` with +/// `torch::nn::L1LossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(L1Loss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ KLDivLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// The Kullback-Leibler divergence loss measure +/// See https://pytorch.org/docs/main/nn.html#torch.nn.KLDivLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::KLDivLossOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// KLDivLoss model(KLDivLossOptions().reduction(torch::kNone)); +/// ``` +struct TORCH_API KLDivLossImpl : Cloneable { + explicit KLDivLossImpl(KLDivLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `KLDivLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + KLDivLossOptions options; +}; + +/// A `ModuleHolder` subclass for `KLDivLossImpl`. +/// See the documentation for `KLDivLossImpl` class to learn what methods it +/// provides, and examples of how to use `KLDivLoss` with +/// `torch::nn::KLDivLossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(KLDivLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MSELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the mean squared error (squared L2 norm) +/// between each element in the input :math:`x` and target :math:`y`. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MSELoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MSELossOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MSELoss model(MSELossOptions(torch::kNone)); +/// ``` +struct TORCH_API MSELossImpl : Cloneable { + explicit MSELossImpl(MSELossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `MSELoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + MSELossOptions options; +}; + +/// A `ModuleHolder` subclass for `MSELossImpl`. +/// See the documentation for `MSELossImpl` class to learn what methods it +/// provides, and examples of how to use `MSELoss` with +/// `torch::nn::MSELossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(MSELoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCELoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the Binary Cross Entropy +/// between the target and the output. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.BCELoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::BCELossOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// BCELoss model(BCELossOptions().reduction(torch::kNone).weight(weight)); +/// ``` +struct TORCH_API BCELossImpl : Cloneable { + explicit BCELossImpl(BCELossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `BCELoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + BCELossOptions options; +}; + +/// A `ModuleHolder` subclass for `BCELossImpl`. +/// See the documentation for `BCELossImpl` class to learn what methods it +/// provides, and examples of how to use `BCELoss` with +/// `torch::nn::BCELossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(BCELoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HingeEmbeddingLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the loss given an input tensor :math:`x` +/// and a labels tensor :math:`y` (containing 1 or -1). +/// See https://pytorch.org/docs/main/nn.html#torch.nn.HingeEmbeddingLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::HingeEmbeddingLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// HingeEmbeddingLoss +/// model(HingeEmbeddingLossOptions().margin(4).reduction(torch::kNone)); +/// ``` +struct TORCH_API HingeEmbeddingLossImpl : Cloneable { + explicit HingeEmbeddingLossImpl(HingeEmbeddingLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `HingeEmbeddingLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + HingeEmbeddingLossOptions options; +}; + +/// A `ModuleHolder` subclass for `HingeEmbeddingLossImpl`. +/// See the documentation for `HingeEmbeddingLossImpl` class to learn what +/// methods it provides, and examples of how to use `HingeEmbeddingLoss` with +/// `torch::nn::HingeEmbeddingLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(HingeEmbeddingLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiMarginLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that optimizes a multi-class classification hinge +/// loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) +/// and output :math:`y` (which is a 1D tensor of target class indices, :math:`0 +/// \leq y \leq \text{x.size}(1)-1`). See +/// https://pytorch.org/docs/main/nn.html#torch.nn.MultiMarginLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MultiMarginLossOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MultiMarginLoss model(MultiMarginLossOptions().margin(2).weight(weight)); +/// ``` +struct TORCH_API MultiMarginLossImpl : public Cloneable { + explicit MultiMarginLossImpl(MultiMarginLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `MultiMarginLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + MultiMarginLossOptions options; +}; + +/// A `ModuleHolder` subclass for `MultiMarginLossImpl`. +/// See the documentation for `MultiMarginLossImpl` class to learn what methods +/// it provides, and examples of how to use `MultiMarginLoss` with +/// `torch::nn::MultiMarginLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(MultiMarginLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CosineEmbeddingLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the loss given input tensors +/// `input1`, `input2`, and a `Tensor` label `target` with values 1 or +/// -1. This is used for measuring whether two inputs are similar or +/// dissimilar, using the cosine distance, and is typically used for learning +/// nonlinear embeddings or semi-supervised learning. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.CosineEmbeddingLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::CosineEmbeddingLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// CosineEmbeddingLoss model(CosineEmbeddingLossOptions().margin(0.5)); +/// ``` +struct TORCH_API CosineEmbeddingLossImpl + : public Cloneable { + explicit CosineEmbeddingLossImpl(CosineEmbeddingLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `CosineEmbeddingLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward( + const Tensor& input1, + const Tensor& input2, + const Tensor& target); + + /// The options with which this `Module` was constructed. + CosineEmbeddingLossOptions options; +}; + +/// A `ModuleHolder` subclass for `CosineEmbeddingLossImpl`. +/// See the documentation for `CosineEmbeddingLossImpl` class to learn what +/// methods it provides, and examples of how to use `CosineEmbeddingLoss` with +/// `torch::nn::CosineEmbeddingLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(CosineEmbeddingLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SmoothL1Loss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that uses a squared term if the absolute +/// element-wise error falls below beta and an L1 term otherwise. +/// It is less sensitive to outliers than the `MSELoss` and in some cases +/// prevents exploding gradients (e.g. see the paper `Fast R-CNN` by Ross +/// Girshick). See https://pytorch.org/docs/main/nn.html#torch.nn.SmoothL1Loss +/// to learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::SmoothL1LossOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// SmoothL1Loss model(SmoothL1LossOptions().reduction(torch::kNone).beta(0.5)); +/// ``` +struct TORCH_API SmoothL1LossImpl : public Cloneable { + explicit SmoothL1LossImpl(SmoothL1LossOptions options = {}); + + void reset() override; + + /// Pretty prints the `L1Loss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + SmoothL1LossOptions options; +}; + +/// A `ModuleHolder` subclass for `SmoothL1LossImpl`. +/// See the documentation for `SmoothL1LossImpl` class to learn what methods it +/// provides, and examples of how to use `SmoothL1Loss` with +/// `torch::nn::SmoothL1LossOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(SmoothL1Loss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HuberLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that uses a squared term if the absolute +/// element-wise error falls below delta and a delta-scaled L1 term otherwise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.HuberLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::HuberLossOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// HuberLoss model(HuberLossOptions().reduction(torch::kNone).delta(0.5)); +/// ``` +struct TORCH_API HuberLossImpl : public Cloneable { + explicit HuberLossImpl(HuberLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `HuberLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + HuberLossOptions options; +}; + +/// A `ModuleHolder` subclass for `HuberLossImpl`. +/// See the documentation for `HuberLossImpl` class to learn what methods it +/// provides, and examples of how to use `HuberLoss` with +/// `torch::nn::HuberLossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(HuberLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelMarginLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that optimizes a multi-class multi-classification +/// hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch +/// `Tensor`) and output :math:`y` (which is a 2D `Tensor` of target class +/// indices). See +/// https://pytorch.org/docs/main/nn.html#torch.nn.MultiLabelMarginLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MultiLabelMarginLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MultiLabelMarginLoss model(MultiLabelMarginLossOptions(torch::kNone)); +/// ``` +struct TORCH_API MultiLabelMarginLossImpl + : public Cloneable { + explicit MultiLabelMarginLossImpl(MultiLabelMarginLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `L1Loss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + MultiLabelMarginLossOptions options; +}; + +/// A `ModuleHolder` subclass for `MultiLabelMarginLossImpl`. +/// See the documentation for `MultiLabelMarginLossImpl` class to learn what +/// methods it provides, and examples of how to use `MultiLabelMarginLoss` with +/// `torch::nn::MultiLabelMarginLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(MultiLabelMarginLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SoftMarginLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that optimizes a two-class classification +/// logistic loss between input tensor :math:`x` and target tensor :math:`y` +/// (containing 1 or -1). +/// See https://pytorch.org/docs/main/nn.html#torch.nn.SoftMarginLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::SoftMarginLossOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// SoftMarginLoss model(SoftMarginLossOptions(torch::kNone)); +/// ``` +struct TORCH_API SoftMarginLossImpl : public Cloneable { + explicit SoftMarginLossImpl(SoftMarginLossOptions options_ = {}); + + /// Pretty prints the `SoftMarginLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + void reset() override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + SoftMarginLossOptions options; +}; + +/// A `ModuleHolder` subclass for `SoftMarginLossImpl`. +/// See the documentation for `SoftMarginLossImpl` class to learn what methods +/// it provides, and examples of how to use `SoftMarginLoss` with +/// `torch::nn::SoftMarginLossOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(SoftMarginLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiLabelSoftMarginLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that optimizes a multi-label one-versus-all +/// loss based on max-entropy, between input :math:`x` and target :math:`y` of +/// size :math:`(N, C)`. See +/// https://pytorch.org/docs/main/nn.html#torch.nn.MultiLabelSoftMarginLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MultiLabelSoftMarginLossOptions` class +/// to learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MultiLabelSoftMarginLoss +/// model(MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight)); +/// ``` +struct TORCH_API MultiLabelSoftMarginLossImpl + : public Cloneable { + explicit MultiLabelSoftMarginLossImpl( + MultiLabelSoftMarginLossOptions options_ = {}); + + /// Pretty prints the `MultiLabelSoftMarginLoss` module into the given + /// `stream`. + void pretty_print(std::ostream& stream) const override; + + void reset() override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + MultiLabelSoftMarginLossOptions options; +}; + +/// A `ModuleHolder` subclass for `MultiLabelSoftMarginLossImpl`. +/// See the documentation for `MultiLabelSoftMarginLossImpl` class to learn what +/// methods it provides, and examples of how to use `MultiLabelSoftMarginLoss` +/// with `torch::nn::MultiLabelSoftMarginLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(MultiLabelSoftMarginLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the triplet loss given an input +/// tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater +/// than :math:`0`. This is used for measuring a relative similarity between +/// samples. A triplet is composed by `a`, `p` and `n` (i.e., `anchor`, +/// `positive examples` and `negative examples` respectively). The +/// shapes of all input tensors should be :math:`(N, D)`. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.TripletMarginLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::TripletMarginLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// TripletMarginLoss +/// model(TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false)); +/// ``` +struct TORCH_API TripletMarginLossImpl + : public Cloneable { + explicit TripletMarginLossImpl(TripletMarginLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `TripletMarginLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward( + const Tensor& anchor, + const Tensor& positive, + const Tensor& negative); + + /// The options with which this `Module` was constructed. + TripletMarginLossOptions options; +}; + +/// A `ModuleHolder` subclass for `TripletMarginLossImpl`. +/// See the documentation for `TripletMarginLossImpl` class to learn what +/// methods it provides, and examples of how to use `TripletMarginLoss` with +/// `torch::nn::TripletMarginLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(TripletMarginLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TripletMarginWithDistanceLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the triplet loss given input +/// tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor, +/// positive, and negative examples, respectively); and a nonnegative, +/// real-valued function +/// ("distance function") used to compute the relationships between the anchor +/// and positive example ("positive distance") and the anchor and negative +/// example ("negative distance"). +/// See +/// https://pytorch.org/docs/main/nn.html#torch.nn.TripletMarginWithDistanceLoss +/// to learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::TripletMarginWithDistanceLossOptions` +/// class to learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// TripletMarginWithDistanceLoss +/// model(TripletMarginWithDistanceLossOptions().margin(3).swap(false)); +/// ``` +struct TORCH_API TripletMarginWithDistanceLossImpl + : public Cloneable { + explicit TripletMarginWithDistanceLossImpl( + TripletMarginWithDistanceLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `TripletMarginWithDistanceLoss` module into the given + /// `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward( + const Tensor& anchor, + const Tensor& positive, + const Tensor& negative); + + /// The options with which this `Module` was constructed. + TripletMarginWithDistanceLossOptions options; +}; + +/// A `ModuleHolder` subclass for `TripletMarginWithDistanceLossImpl`. +/// See the documentation for `TripletMarginWithDistanceLossImpl` class to learn +/// what methods it provides, and examples of how to use +/// `TripletMarginWithDistanceLoss` with +/// `torch::nn::TripletMarginWithDistanceLossOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(TripletMarginWithDistanceLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CTCLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// The Connectionist Temporal Classification loss. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.CTCLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::CTCLossOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// CTCLoss +/// model(CTCLossOptions().blank(42).zero_infinity(false).reduction(torch::kSum)); +/// ``` +struct TORCH_API CTCLossImpl : public Cloneable { + explicit CTCLossImpl(CTCLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `CTCLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward( + const Tensor& log_probs, + const Tensor& targets, + const Tensor& input_lengths, + const Tensor& target_lengths); + + /// The options with which this `Module` was constructed. + CTCLossOptions options; +}; + +/// A `ModuleHolder` subclass for `CTCLossImpl`. +/// See the documentation for `CTCLossImpl` class to learn what methods it +/// provides, and examples of how to use `CTCLoss` with +/// `torch::nn::CTCLossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(CTCLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PoissonNLLLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Negative log likelihood loss with Poisson distribution of target. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.PoissonNLLLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::PoissonNLLLossOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// PoissonNLLLoss +/// model(PoissonNLLLossOptions().log_input(false).full(true).eps(0.42).reduction(torch::kSum)); +/// ``` +struct TORCH_API PoissonNLLLossImpl : public Cloneable { + explicit PoissonNLLLossImpl(PoissonNLLLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `PoissonNLLLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& log_input, const Tensor& targets); + + /// The options with which this `Module` was constructed. + PoissonNLLLossOptions options; +}; + +/// A `ModuleHolder` subclass for `PoissonNLLLossImpl`. +/// See the documentation for `PoissonNLLLossImpl` class to learn what methods +/// it provides, and examples of how to use `PoissonNLLLoss` with +/// `torch::nn::PoissonNLLLossOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(PoissonNLLLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MarginRankingLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that measures the loss given +/// inputs :math:`x1`, :math:`x2`, two 1D mini-batch `Tensors`, +/// and a label 1D mini-batch tensor :math:`y` (containing 1 or -1). +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MarginRankingLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MarginRankingLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MarginRankingLoss +/// model(MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)); +/// ``` +struct TORCH_API MarginRankingLossImpl + : public Cloneable { + explicit MarginRankingLossImpl(MarginRankingLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `MarginRankingLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward( + const Tensor& input1, + const Tensor& input2, + const Tensor& targets); + + /// The options with which this `Module` was constructed. + MarginRankingLossOptions options; +}; + +/// A `ModuleHolder` subclass for `MarginRankingLossImpl`. +/// See the documentation for `MarginRankingLossImpl` class to learn what +/// methods it provides, and examples of how to use `MarginRankingLoss` with +/// `torch::nn::MarginRankingLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(MarginRankingLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ NLLLoss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// The negative log likelihood loss. It is useful to train a classification +/// problem with `C` classes. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.NLLLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::NLLLossOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// NLLLoss model(NLLLossOptions().ignore_index(-100).reduction(torch::kMean)); +/// ``` +struct TORCH_API NLLLossImpl : public Cloneable { + explicit NLLLossImpl(NLLLossOptions options_ = {}); + + /// Pretty prints the `NLLLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + void reset() override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + NLLLossOptions options; + + /// A manual rescaling weight given to to each class. + Tensor weight; +}; + +/// A `ModuleHolder` subclass for `NLLLossImpl`. +/// See the documentation for `NLLLossImpl` class to learn what methods it +/// provides, and examples of how to use `NLLLoss` with +/// `torch::nn::NLLLossOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(NLLLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossEntropyLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Creates a criterion that computes cross entropy loss between input and +/// target. See +/// https://pytorch.org/docs/main/nn.html#torch.nn.CrossEntropyLoss to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::CrossEntropyLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// CrossEntropyLoss +/// model(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean)); +/// ``` +struct TORCH_API CrossEntropyLossImpl : public Cloneable { + explicit CrossEntropyLossImpl(CrossEntropyLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `CrossEntropyLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + CrossEntropyLossOptions options; + + /// A manual rescaling weight given to to each class. + Tensor weight; +}; + +/// A `ModuleHolder` subclass for `CrossEntropyLossImpl`. +/// See the documentation for `CrossEntropyLossImpl` class to learn what methods +/// it provides, and examples of how to use `CrossEntropyLoss` with +/// `torch::nn::CrossEntropyLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(CrossEntropyLoss); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BCEWithLogitsLoss +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// This loss combines a `Sigmoid` layer and the `BCELoss` in one single +/// class. This version is more numerically stable than using a plain `Sigmoid` +/// followed by a `BCELoss` as, by combining the operations into one layer, +/// we take advantage of the log-sum-exp trick for numerical stability. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.BCEWithLogitsLoss to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::BCEWithLogitsLossOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// BCEWithLogitsLoss +/// model(BCEWithLogitsLossOptions().reduction(torch::kNone).weight(weight)); +/// ``` +struct TORCH_API BCEWithLogitsLossImpl + : public Cloneable { + explicit BCEWithLogitsLossImpl(BCEWithLogitsLossOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `BCEWithLogitsLoss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + BCEWithLogitsLossOptions options; + + /// A manual rescaling weight given to the loss of each batch element. + Tensor weight; + + /// A weight of positive examples. + Tensor pos_weight; +}; + +/// A `ModuleHolder` subclass for `BCEWithLogitsLossImpl`. +/// See the documentation for `BCEWithLogitsLossImpl` class to learn what +/// methods it provides, and examples of how to use `BCEWithLogitsLoss` with +/// `torch::nn::BCEWithLogitsLossOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(BCEWithLogitsLoss); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/normalization.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/normalization.h new file mode 100644 index 0000000000000000000000000000000000000000..7fe0396319d7b6c2336097c198fc0017e2e70563 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/normalization.h @@ -0,0 +1,197 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LayerNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies Layer Normalization over a mini-batch of inputs as described in +/// the paper `Layer Normalization`_ . +/// See https://pytorch.org/docs/main/nn.html#torch.nn.LayerNorm to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::LayerNormOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// LayerNorm model(LayerNormOptions({2, +/// 2}).elementwise_affine(false).eps(2e-5)); +/// ``` +class TORCH_API LayerNormImpl : public torch::nn::Cloneable { + public: + LayerNormImpl(std::vector normalized_shape) + : LayerNormImpl(LayerNormOptions(std::move(normalized_shape))) {} + explicit LayerNormImpl(LayerNormOptions options_); + + void reset() override; + + void reset_parameters(); + + /// Pretty prints the `LayerNorm` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// Applies layer normalization over a mini-batch of inputs as described in + /// the paper `Layer Normalization`_ . + /// + /// The mean and standard-deviation are calculated separately over the last + /// certain number dimensions which have to be of the shape specified by + /// input `normalized_shape`. + /// + /// `Layer Normalization`: https://arxiv.org/abs/1607.06450 + Tensor forward(const Tensor& input); + + /// The options with which this module was constructed. + LayerNormOptions options; + + /// The learned weight. + /// Initialized to ones if the `elementwise_affine` option is set to `true` + /// upon construction. + Tensor weight; + + /// The learned bias. + /// Initialized to zeros `elementwise_affine` option is set to `true` upon + /// construction. + Tensor bias; +}; + +/// A `ModuleHolder` subclass for `LayerNormImpl`. +/// See the documentation for `LayerNormImpl` class to learn what methods it +/// provides, and examples of how to use `LayerNorm` with +/// `torch::nn::LayerNormOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(LayerNorm); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LocalResponseNorm +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies local response normalization over an input signal composed +/// of several input planes, where channels occupy the second dimension. +/// Applies normalization across channels. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.LocalResponseNorm to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::LocalResponseNormOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// LocalResponseNorm +/// model(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.)); +/// ``` +class TORCH_API LocalResponseNormImpl + : public Cloneable { + public: + LocalResponseNormImpl(int64_t size) + : LocalResponseNormImpl(LocalResponseNormOptions(size)) {} + explicit LocalResponseNormImpl(const LocalResponseNormOptions& options_); + + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `LocalResponseNormImpl` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + LocalResponseNormOptions options; +}; + +/// A `ModuleHolder` subclass for `LocalResponseNormImpl`. +/// See the documentation for `LocalResponseNormImpl` class to learn what +/// methods it provides, and examples of how to use `LocalResponseNorm` with +/// `torch::nn::LocalResponseNormOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(LocalResponseNorm); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossMapLRN2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// See the documentation for `torch::nn::CrossMapLRN2dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// CrossMapLRN2d model(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10)); +/// ``` +class TORCH_API CrossMapLRN2dImpl + : public torch::nn::Cloneable { + public: + CrossMapLRN2dImpl(int64_t size) + : CrossMapLRN2dImpl(CrossMapLRN2dOptions(size)) {} + explicit CrossMapLRN2dImpl(const CrossMapLRN2dOptions& options_) + : options(options_) {} + + void reset() override; + + /// Pretty prints the `CrossMapLRN2d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + torch::Tensor forward(const torch::Tensor& input); + + CrossMapLRN2dOptions options; +}; + +/// A `ModuleHolder` subclass for `CrossMapLRN2dImpl`. +/// See the documentation for `CrossMapLRN2dImpl` class to learn what methods it +/// provides, and examples of how to use `CrossMapLRN2d` with +/// `torch::nn::CrossMapLRN2dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(CrossMapLRN2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GroupNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies Group Normalization over a mini-batch of inputs as described in +/// the paper `Group Normalization`_ . +/// See https://pytorch.org/docs/main/nn.html#torch.nn.GroupNorm to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::GroupNormOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// GroupNorm model(GroupNormOptions(2, 2).eps(2e-5).affine(false)); +/// ``` +class TORCH_API GroupNormImpl : public torch::nn::Cloneable { + public: + GroupNormImpl(int64_t num_groups, int64_t num_channels) + : GroupNormImpl(GroupNormOptions(num_groups, num_channels)) {} + explicit GroupNormImpl(const GroupNormOptions& options_); + + void reset() override; + + void reset_parameters(); + + /// Pretty prints the `GroupNorm` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + /// The options with which this module was constructed. + GroupNormOptions options; + + /// The learned weight. + Tensor weight; + + /// The learned bias. + Tensor bias; +}; + +/// A `ModuleHolder` subclass for `GroupNormImpl`. +/// See the documentation for `GroupNormImpl` class to learn what methods it +/// provides, and examples of how to use `GroupNorm` with +/// `torch::nn::GroupNormOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(GroupNorm); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/padding.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/padding.h new file mode 100644 index 0000000000000000000000000000000000000000..855608438ce0b5db94a56953a2cb1c1077e0038a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/padding.h @@ -0,0 +1,376 @@ +#pragma once + +#include +#include +#include + +#include + +namespace torch::nn { + +/// Base class for all (dimension-specialized) ReflectionPad modules. +template +class TORCH_API ReflectionPadImpl : public torch::nn::Cloneable { + public: + ReflectionPadImpl(ExpandingArray padding) + : ReflectionPadImpl(ReflectionPadOptions(padding)) {} + explicit ReflectionPadImpl(const ReflectionPadOptions& options_); + + void reset() override; + + Tensor forward(const Tensor& input); + + /// Pretty prints the `ReflectionPad{1,2}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + ReflectionPadOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReflectionPad1d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies ReflectionPad over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ReflectionPad1d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ReflectionPad1dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ReflectionPad1d model(ReflectionPad1dOptions({3, 1})); +/// ``` +class TORCH_API ReflectionPad1dImpl + : public ReflectionPadImpl<1, ReflectionPad1dImpl> { + public: + using ReflectionPadImpl<1, ReflectionPad1dImpl>::ReflectionPadImpl; +}; + +/// A `ModuleHolder` subclass for `ReflectionPad1dImpl`. +/// See the documentation for `ReflectionPad1dImpl` class to learn what methods +/// it provides, and examples of how to use `ReflectionPad1d` with +/// `torch::nn::ReflectionPad1dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(ReflectionPad1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReflectionPad2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies ReflectionPad over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ReflectionPad2d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ReflectionPad2dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ReflectionPad2d model(ReflectionPad2dOptions({1, 1, 2, 0})); +/// ``` +class TORCH_API ReflectionPad2dImpl + : public ReflectionPadImpl<2, ReflectionPad2dImpl> { + public: + using ReflectionPadImpl<2, ReflectionPad2dImpl>::ReflectionPadImpl; +}; + +/// A `ModuleHolder` subclass for `ReflectionPad2dImpl`. +/// See the documentation for `ReflectionPad2dImpl` class to learn what methods +/// it provides, and examples of how to use `ReflectionPad2d` with +/// `torch::nn::ReflectionPad2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(ReflectionPad2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReflectionPad3d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies ReflectionPad over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ReflectionPad3d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ReflectionPad3dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ReflectionPad3d model(ReflectionPad3dOptions(1)); +/// ReflectionPad3d model(ReflectionPad3dOptions({1, 1, 2, 0, 1, 2})); +/// ``` +class TORCH_API ReflectionPad3dImpl + : public ReflectionPadImpl<3, ReflectionPad3dImpl> { + public: + using ReflectionPadImpl<3, ReflectionPad3dImpl>::ReflectionPadImpl; +}; + +/// A `ModuleHolder` subclass for `ReflectionPad3dImpl`. +/// See the documentation for `ReflectionPad3dImpl` class to learn what methods +/// it provides, and examples of how to use `ReflectionPad3d` with +/// `torch::nn::ReflectionPad3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(ReflectionPad3d); + +// ============================================================================ + +/// Base class for all (dimension-specialized) ReplicationPad modules. +template +class TORCH_API ReplicationPadImpl : public torch::nn::Cloneable { + public: + ReplicationPadImpl(ExpandingArray padding) + : ReplicationPadImpl(ReplicationPadOptions(padding)) {} + explicit ReplicationPadImpl(const ReplicationPadOptions& options_); + + void reset() override; + + Tensor forward(const Tensor& input); + + /// Pretty prints the `ReplicationPad{1,2}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + ReplicationPadOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReplicationPad1d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies ReplicationPad over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ReplicationPad1d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ReplicationPad1dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ReplicationPad1d model(ReplicationPad1dOptions({3, 1})); +/// ``` +class TORCH_API ReplicationPad1dImpl + : public ReplicationPadImpl<1, ReplicationPad1dImpl> { + public: + using ReplicationPadImpl<1, ReplicationPad1dImpl>::ReplicationPadImpl; +}; + +/// A `ModuleHolder` subclass for `ReplicationPad1dImpl`. +/// See the documentation for `ReplicationPad1dImpl` class to learn what methods +/// it provides, and examples of how to use `ReplicationPad1d` with +/// `torch::nn::ReplicationPad1dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(ReplicationPad1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReplicationPad2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies ReplicationPad over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ReplicationPad2d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ReplicationPad2dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ReplicationPad2d model(ReplicationPad2dOptions({1, 1, 2, 0})); +/// ``` +class TORCH_API ReplicationPad2dImpl + : public ReplicationPadImpl<2, ReplicationPad2dImpl> { + public: + using ReplicationPadImpl<2, ReplicationPad2dImpl>::ReplicationPadImpl; +}; + +/// A `ModuleHolder` subclass for `ReplicationPad2dImpl`. +/// See the documentation for `ReplicationPad2dImpl` class to learn what methods +/// it provides, and examples of how to use `ReplicationPad2d` with +/// `torch::nn::ReplicationPad2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(ReplicationPad2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReplicationPad3d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies ReplicationPad over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ReplicationPad3d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ReplicationPad3dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ReplicationPad3d model(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2})); +/// ``` +class TORCH_API ReplicationPad3dImpl + : public ReplicationPadImpl<3, ReplicationPad3dImpl> { + public: + using ReplicationPadImpl<3, ReplicationPad3dImpl>::ReplicationPadImpl; +}; + +/// A `ModuleHolder` subclass for `ReplicationPad3dImpl`. +/// See the documentation for `ReplicationPad3dImpl` class to learn what methods +/// it provides, and examples of how to use `ReplicationPad3d` with +/// `torch::nn::ReplicationPad3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(ReplicationPad3d); + +// ============================================================================ + +/// Base class for all (dimension-specialized) ZeroPad modules. +template +class TORCH_API ZeroPadImpl : public torch::nn::Cloneable { + public: + ZeroPadImpl(ExpandingArray padding) + : ZeroPadImpl(ZeroPadOptions(padding)) {} + explicit ZeroPadImpl(const ZeroPadOptions& options_); + + void reset() override; + + Tensor forward(const Tensor& input); + + /// Pretty prints the `ZeroPad{1,2}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + ZeroPadOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ZeroPad1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Applies ZeroPad over a 1-D input. +class TORCH_API ZeroPad1dImpl : public ZeroPadImpl<1, ZeroPad1dImpl> { + public: + using ZeroPadImpl<1, ZeroPad1dImpl>::ZeroPadImpl; +}; + +/// A `ModuleHolder` subclass for `ZeroPad1dImpl`. +/// See the documentation for `ZeroPad1dImpl` class to learn what methods it +/// provides, and examples of how to use `ZeroPad1d` with +/// `torch::nn::ZeroPad1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(ZeroPad1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ZeroPad2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Applies ZeroPad over a 2-D input. +class TORCH_API ZeroPad2dImpl : public ZeroPadImpl<2, ZeroPad2dImpl> { + public: + using ZeroPadImpl<2, ZeroPad2dImpl>::ZeroPadImpl; +}; + +/// A `ModuleHolder` subclass for `ZeroPad2dImpl`. +/// See the documentation for `ZeroPad2dImpl` class to learn what methods it +/// provides, and examples of how to use `ZeroPad2d` with +/// `torch::nn::ZeroPad2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(ZeroPad2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ZeroPad3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Applies ZeroPad over a 3-D input. +class TORCH_API ZeroPad3dImpl : public ZeroPadImpl<3, ZeroPad3dImpl> { + public: + using ZeroPadImpl<3, ZeroPad3dImpl>::ZeroPadImpl; +}; + +/// A `ModuleHolder` subclass for `ZeroPad3dImpl`. +/// See the documentation for `ZeroPad3dImpl` class to learn what methods it +/// provides, and examples of how to use `ZeroPad3d` with +/// `torch::nn::ZeroPad3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(ZeroPad3d); + +// ============================================================================ + +/// Base class for all (dimension-specialized) ConstantPad modules. +template +class TORCH_API ConstantPadImpl : public torch::nn::Cloneable { + public: + ConstantPadImpl(ExpandingArray padding, double value) + : ConstantPadImpl(ConstantPadOptions(padding, value)) {} + explicit ConstantPadImpl(const ConstantPadOptions& options_); + + void reset() override; + + Tensor forward(const Tensor& input); + + /// Pretty prints the `ConstantPad{1,2}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + ConstantPadOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConstantPad1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies ConstantPad over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ConstantPad1d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ConstantPad1dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ConstantPad1d model(ConstantPad1dOptions({3, 1}, 3.5)); +/// ``` +class TORCH_API ConstantPad1dImpl + : public ConstantPadImpl<1, ConstantPad1dImpl> { + public: + using ConstantPadImpl<1, ConstantPad1dImpl>::ConstantPadImpl; +}; + +/// A `ModuleHolder` subclass for `ConstantPad1dImpl`. +/// See the documentation for `ConstantPad1dImpl` class to learn what methods it +/// provides, and examples of how to use `ConstantPad1d` with +/// `torch::nn::ConstantPad1dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(ConstantPad1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConstantPad2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies ConstantPad over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ConstantPad2d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ConstantPad2dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ConstantPad2d model(ConstantPad2dOptions({3, 0, 2, 1}, 3.5)); +/// ``` +class TORCH_API ConstantPad2dImpl + : public ConstantPadImpl<2, ConstantPad2dImpl> { + public: + using ConstantPadImpl<2, ConstantPad2dImpl>::ConstantPadImpl; +}; + +/// A `ModuleHolder` subclass for `ConstantPad2dImpl`. +/// See the documentation for `ConstantPad2dImpl` class to learn what methods it +/// provides, and examples of how to use `ConstantPad2d` with +/// `torch::nn::ConstantPad2dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(ConstantPad2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConstantPad3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies ConstantPad over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.ConstantPad3d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::ConstantPad3dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// ConstantPad3d model(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5)); +/// ``` +class TORCH_API ConstantPad3dImpl + : public ConstantPadImpl<3, ConstantPad3dImpl> { + public: + using ConstantPadImpl<3, ConstantPad3dImpl>::ConstantPadImpl; +}; + +/// A `ModuleHolder` subclass for `ConstantPad3dImpl`. +/// See the documentation for `ConstantPad3dImpl` class to learn what methods it +/// provides, and examples of how to use `ConstantPad3d` with +/// `torch::nn::ConstantPad3dOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(ConstantPad3d); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h new file mode 100644 index 0000000000000000000000000000000000000000..ce981c3a1c341078d1f072c83fb371979da1e707 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include + +#include + +namespace torch::nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelShuffle +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` +/// to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an +/// upscale factor. See +/// https://pytorch.org/docs/main/nn.html#torch.nn.PixelShuffle to learn about +/// the exact behavior of this module. +/// +/// See the documentation for `torch::nn::PixelShuffleOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// PixelShuffle model(PixelShuffleOptions(5)); +/// ``` +struct TORCH_API PixelShuffleImpl + : public torch::nn::Cloneable { + explicit PixelShuffleImpl(const PixelShuffleOptions& options_); + + /// Pretty prints the `PixelShuffle` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + void reset() override; + + /// The options with which this `Module` was constructed. + PixelShuffleOptions options; +}; + +/// A `ModuleHolder` subclass for `PixelShuffleImpl`. +/// See the documentation for `PixelShuffleImpl` class to learn what methods it +/// provides, and examples of how to use `PixelShuffle` with +/// `torch::nn::PixelShuffleOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(PixelShuffle); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelUnshuffle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Reverses the PixelShuffle operation by rearranging elements in a tensor of +/// shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape :math:`(*, +/// C \times r^2, H, W)`, where r is a downscale factor. See +/// https://pytorch.org/docs/main/nn.html#torch.nn.PixelUnshuffle to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::PixelUnshuffleOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// PixelUnshuffle model(PixelUnshuffleOptions(5)); +/// ``` +struct TORCH_API PixelUnshuffleImpl + : public torch::nn::Cloneable { + explicit PixelUnshuffleImpl(const PixelUnshuffleOptions& options_); + + /// Pretty prints the `PixelUnshuffle` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + void reset() override; + + /// The options with which this `Module` was constructed. + PixelUnshuffleOptions options; +}; + +/// A `ModuleHolder` subclass for `PixelUnshuffleImpl`. +/// See the documentation for `PixelUnshuffleImpl` class to learn what methods +/// it provides, and examples of how to use `PixelUnshuffle` with +/// `torch::nn::PixelUnshuffleOptions`. See the documentation for `ModuleHolder` +/// to learn about PyTorch's module storage semantics. +TORCH_MODULE(PixelUnshuffle); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/pooling.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..17ed12f4cc037ea0249f20edd3f94e509ceadd19 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/pooling.h @@ -0,0 +1,777 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace torch::nn { + +/// Base class for all (dimension-specialized) avgpool modules. +template +class TORCH_API AvgPoolImpl : public torch::nn::Cloneable { + public: + AvgPoolImpl(ExpandingArray kernel_size) + : AvgPoolImpl(AvgPoolOptions(kernel_size)) {} + explicit AvgPoolImpl(const AvgPoolOptions& options_); + + void reset() override; + + /// Pretty prints the `AvgPool{1,2,3}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + AvgPoolOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies avgpool over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AvgPool1d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AvgPool1dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AvgPool1d model(AvgPool1dOptions(3).stride(2)); +/// ``` +class TORCH_API AvgPool1dImpl : public AvgPoolImpl<1, AvgPool1dImpl> { + public: + using AvgPoolImpl<1, AvgPool1dImpl>::AvgPoolImpl; + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AvgPool1dImpl`. +/// See the documentation for `AvgPool1dImpl` class to learn what methods it +/// provides, and examples of how to use `AvgPool1d` with +/// `torch::nn::AvgPool1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(AvgPool1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies avgpool over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AvgPool2d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AvgPool2dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AvgPool2d model(AvgPool2dOptions({3, 2}).stride({2, 2})); +/// ``` +class TORCH_API AvgPool2dImpl : public AvgPoolImpl<2, AvgPool2dImpl> { + public: + using AvgPoolImpl<2, AvgPool2dImpl>::AvgPoolImpl; + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AvgPool2dImpl`. +/// See the documentation for `AvgPool2dImpl` class to learn what methods it +/// provides, and examples of how to use `AvgPool2d` with +/// `torch::nn::AvgPool2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(AvgPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies avgpool over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AvgPool3d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AvgPool3dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AvgPool3d model(AvgPool3dOptions(5).stride(2)); +/// ``` +class TORCH_API AvgPool3dImpl : public AvgPoolImpl<3, AvgPool3dImpl> { + public: + using AvgPoolImpl<3, AvgPool3dImpl>::AvgPoolImpl; + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AvgPool3dImpl`. +/// See the documentation for `AvgPool3dImpl` class to learn what methods it +/// provides, and examples of how to use `AvgPool3d` with +/// `torch::nn::AvgPool3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(AvgPool3d); + +// ============================================================================ + +/// Base class for all (dimension-specialized) maxpool modules. +template +class TORCH_API MaxPoolImpl : public torch::nn::Cloneable { + public: + MaxPoolImpl(ExpandingArray kernel_size) + : MaxPoolImpl(MaxPoolOptions(kernel_size)) {} + explicit MaxPoolImpl(const MaxPoolOptions& options_); + + void reset() override; + + /// Pretty prints the `MaxPool{1,2,3}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + MaxPoolOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxpool over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxPool1d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MaxPool1dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MaxPool1d model(MaxPool1dOptions(3).stride(2)); +/// ``` +class TORCH_API MaxPool1dImpl : public MaxPoolImpl<1, MaxPool1dImpl> { + public: + using MaxPoolImpl<1, MaxPool1dImpl>::MaxPoolImpl; + Tensor forward(const Tensor& input); + + /// Returns the outputs and the indices of the max values. + /// Useful for `torch::nn::MaxUnpool1d` later. + std::tuple forward_with_indices(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `MaxPool1dImpl`. +/// See the documentation for `MaxPool1dImpl` class to learn what methods it +/// provides, and examples of how to use `MaxPool1d` with +/// `torch::nn::MaxPool1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(MaxPool1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxpool over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxPool2d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MaxPool2dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MaxPool2d model(MaxPool2dOptions({3, 2}).stride({2, 2})); +/// ``` +class TORCH_API MaxPool2dImpl : public MaxPoolImpl<2, MaxPool2dImpl> { + public: + using MaxPoolImpl<2, MaxPool2dImpl>::MaxPoolImpl; + Tensor forward(const Tensor& input); + + /// Returns the outputs and the indices of the max values. + /// Useful for `torch::nn::MaxUnpool2d` later. + std::tuple forward_with_indices(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `MaxPool2dImpl`. +/// See the documentation for `MaxPool2dImpl` class to learn what methods it +/// provides, and examples of how to use `MaxPool2d` with +/// `torch::nn::MaxPool2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(MaxPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxpool over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxPool3d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MaxPool3dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MaxPool3d model(MaxPool3dOptions(3).stride(2)); +/// ``` +class TORCH_API MaxPool3dImpl : public MaxPoolImpl<3, MaxPool3dImpl> { + public: + using MaxPoolImpl<3, MaxPool3dImpl>::MaxPoolImpl; + Tensor forward(const Tensor& input); + + /// Returns the outputs and the indices of the max values. + /// Useful for `torch::nn::MaxUnpool3d` later. + std::tuple forward_with_indices(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `MaxPool3dImpl`. +/// See the documentation for `MaxPool3dImpl` class to learn what methods it +/// provides, and examples of how to use `MaxPool3d` with +/// `torch::nn::MaxPool3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(MaxPool3d); + +// ============================================================================ + +/// Base class for all (dimension-specialized) adaptive maxpool modules. +template +class TORCH_API AdaptiveMaxPoolImpl : public torch::nn::Cloneable { + public: + AdaptiveMaxPoolImpl(output_size_t output_size) + : AdaptiveMaxPoolImpl( + AdaptiveMaxPoolOptions(output_size)) {} + explicit AdaptiveMaxPoolImpl( + const AdaptiveMaxPoolOptions& options_) + : options(options_) {} + + void reset() override {} + + /// Pretty prints the `AdaptiveMaxPool{1,2,3}d` module into the given + /// `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::AdaptiveMaxPool" << D << "d" + << "(output_size=" << options.output_size() << ")"; + } + + /// The options with which this `Module` was constructed. + AdaptiveMaxPoolOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveMaxPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies adaptive maxpool over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveMaxPool1d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AdaptiveMaxPool1dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AdaptiveMaxPool1d model(AdaptiveMaxPool1dOptions(3)); +/// ``` +class TORCH_API AdaptiveMaxPool1dImpl + : public AdaptiveMaxPoolImpl<1, ExpandingArray<1>, AdaptiveMaxPool1dImpl> { + public: + using AdaptiveMaxPoolImpl<1, ExpandingArray<1>, AdaptiveMaxPool1dImpl>:: + AdaptiveMaxPoolImpl; + + Tensor forward(const Tensor& input); + + /// Returns the indices along with the outputs. + /// Useful to pass to nn.MaxUnpool1d. + std::tuple forward_with_indices(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AdaptiveMaxPool1dImpl`. +/// See the documentation for `AdaptiveMaxPool1dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveMaxPool1d` with +/// `torch::nn::AdaptiveMaxPool1dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(AdaptiveMaxPool1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveMaxPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies adaptive maxpool over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveMaxPool2d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AdaptiveMaxPool2dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2})); +/// ``` +class TORCH_API AdaptiveMaxPool2dImpl : public AdaptiveMaxPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveMaxPool2dImpl> { + public: + using AdaptiveMaxPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveMaxPool2dImpl>::AdaptiveMaxPoolImpl; + + Tensor forward(const Tensor& input); + + /// Returns the indices along with the outputs. + /// Useful to pass to nn.MaxUnpool2d. + std::tuple forward_with_indices(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AdaptiveMaxPool2dImpl`. +/// See the documentation for `AdaptiveMaxPool2dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveMaxPool2d` with +/// `torch::nn::AdaptiveMaxPool2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(AdaptiveMaxPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveMaxPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies adaptive maxpool over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveMaxPool3d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AdaptiveMaxPool3dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AdaptiveMaxPool3d model(AdaptiveMaxPool3dOptions(3)); +/// ``` +class TORCH_API AdaptiveMaxPool3dImpl : public AdaptiveMaxPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveMaxPool3dImpl> { + public: + using AdaptiveMaxPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveMaxPool3dImpl>::AdaptiveMaxPoolImpl; + + Tensor forward(const Tensor& input); + + /// Returns the indices along with the outputs. + /// Useful to pass to nn.MaxUnpool3d. + std::tuple forward_with_indices(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AdaptiveMaxPool3dImpl`. +/// See the documentation for `AdaptiveMaxPool3dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveMaxPool3d` with +/// `torch::nn::AdaptiveMaxPool3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(AdaptiveMaxPool3d); + +// ============================================================================ + +/// Base class for all (dimension-specialized) adaptive avgpool modules. +template +class TORCH_API AdaptiveAvgPoolImpl : public torch::nn::Cloneable { + public: + AdaptiveAvgPoolImpl(output_size_t output_size) + : AdaptiveAvgPoolImpl( + AdaptiveAvgPoolOptions(output_size)) {} + explicit AdaptiveAvgPoolImpl( + const AdaptiveAvgPoolOptions& options_) + : options(options_) {} + + void reset() override {} + + /// Pretty prints the `AdaptiveAvgPool{1,2,3}d` module into the given + /// `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::AdaptiveAvgPool" << D << "d" + << "(output_size=" << options.output_size() << ")"; + } + + /// The options with which this `Module` was constructed. + AdaptiveAvgPoolOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveAvgPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies adaptive avgpool over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveAvgPool1d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AdaptiveAvgPool1dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AdaptiveAvgPool1d model(AdaptiveAvgPool1dOptions(5)); +/// ``` +class TORCH_API AdaptiveAvgPool1dImpl + : public AdaptiveAvgPoolImpl<1, ExpandingArray<1>, AdaptiveAvgPool1dImpl> { + public: + using AdaptiveAvgPoolImpl<1, ExpandingArray<1>, AdaptiveAvgPool1dImpl>:: + AdaptiveAvgPoolImpl; + + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AdaptiveAvgPool1dImpl`. +/// See the documentation for `AdaptiveAvgPool1dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveAvgPool1d` with +/// `torch::nn::AdaptiveAvgPool1dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(AdaptiveAvgPool1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveAvgPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies adaptive avgpool over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveAvgPool2d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AdaptiveAvgPool2dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AdaptiveAvgPool2d model(AdaptiveAvgPool2dOptions({3, 2})); +/// ``` +class TORCH_API AdaptiveAvgPool2dImpl : public AdaptiveAvgPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveAvgPool2dImpl> { + public: + using AdaptiveAvgPoolImpl< + 2, + ExpandingArrayWithOptionalElem<2>, + AdaptiveAvgPool2dImpl>::AdaptiveAvgPoolImpl; + + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AdaptiveAvgPool2dImpl`. +/// See the documentation for `AdaptiveAvgPool2dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveAvgPool2d` with +/// `torch::nn::AdaptiveAvgPool2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(AdaptiveAvgPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveAvgPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies adaptive avgpool over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveAvgPool3d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::AdaptiveAvgPool3dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// AdaptiveAvgPool3d model(AdaptiveAvgPool3dOptions(3)); +/// ``` +class TORCH_API AdaptiveAvgPool3dImpl : public AdaptiveAvgPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveAvgPool3dImpl> { + public: + using AdaptiveAvgPoolImpl< + 3, + ExpandingArrayWithOptionalElem<3>, + AdaptiveAvgPool3dImpl>::AdaptiveAvgPoolImpl; + + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AdaptiveAvgPool3dImpl`. +/// See the documentation for `AdaptiveAvgPool3dImpl` class to learn what +/// methods it provides, and examples of how to use `AdaptiveAvgPool3d` with +/// `torch::nn::AdaptiveAvgPool3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(AdaptiveAvgPool3d); + +// ============================================================================ + +/// Base class for all (dimension-specialized) maxunpool modules. +template +class TORCH_API MaxUnpoolImpl : public torch::nn::Cloneable { + public: + MaxUnpoolImpl(ExpandingArray kernel_size) + : MaxUnpoolImpl(MaxUnpoolOptions(kernel_size)) {} + explicit MaxUnpoolImpl(const MaxUnpoolOptions& options_); + + void reset() override; + + /// Pretty prints the `MaxUnpool{1,2,3}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + MaxUnpoolOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxUnpool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxunpool over a 1-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxUnpool1d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MaxUnpool1dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MaxUnpool1d model(MaxUnpool1dOptions(3).stride(2).padding(1)); +/// ``` +class TORCH_API MaxUnpool1dImpl : public MaxUnpoolImpl<1, MaxUnpool1dImpl> { + public: + using MaxUnpoolImpl<1, MaxUnpool1dImpl>::MaxUnpoolImpl; + Tensor forward( + const Tensor& input, + const Tensor& indices, + const std::optional>& output_size = std::nullopt); + + protected: + FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional>())}) +}; + +/// A `ModuleHolder` subclass for `MaxUnpool1dImpl`. +/// See the documentation for `MaxUnpool1dImpl` class to learn what methods it +/// provides, and examples of how to use `MaxUnpool1d` with +/// `torch::nn::MaxUnpool1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(MaxUnpool1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxUnpool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxunpool over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxUnpool2d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MaxUnpool2dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MaxUnpool2d model(MaxUnpool2dOptions(3).stride(2).padding(1)); +/// ``` +class TORCH_API MaxUnpool2dImpl : public MaxUnpoolImpl<2, MaxUnpool2dImpl> { + public: + using MaxUnpoolImpl<2, MaxUnpool2dImpl>::MaxUnpoolImpl; + Tensor forward( + const Tensor& input, + const Tensor& indices, + const std::optional>& output_size = std::nullopt); + + protected: + FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional>())}) +}; + +/// A `ModuleHolder` subclass for `MaxUnpool2dImpl`. +/// See the documentation for `MaxUnpool2dImpl` class to learn what methods it +/// provides, and examples of how to use `MaxUnpool2d` with +/// `torch::nn::MaxUnpool2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(MaxUnpool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxUnpool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxunpool over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.MaxUnpool3d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::MaxUnpool3dOptions` class to learn +/// what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// MaxUnpool3d model(MaxUnpool3dOptions(3).stride(2).padding(1)); +/// ``` +class TORCH_API MaxUnpool3dImpl : public MaxUnpoolImpl<3, MaxUnpool3dImpl> { + public: + using MaxUnpoolImpl<3, MaxUnpool3dImpl>::MaxUnpoolImpl; + Tensor forward( + const Tensor& input, + const Tensor& indices, + const std::optional>& output_size = std::nullopt); + + protected: + FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional>())}) +}; + +/// A `ModuleHolder` subclass for `MaxUnpool3dImpl`. +/// See the documentation for `MaxUnpool3dImpl` class to learn what methods it +/// provides, and examples of how to use `MaxUnpool3d` with +/// `torch::nn::MaxUnpool3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(MaxUnpool3d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FractionalMaxPool2d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies fractional maxpool over a 2-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.FractionalMaxPool2d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::FractionalMaxPool2dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// FractionalMaxPool2d model(FractionalMaxPool2dOptions(5).output_size(1)); +/// ``` +class TORCH_API FractionalMaxPool2dImpl + : public torch::nn::Cloneable { + public: + FractionalMaxPool2dImpl(ExpandingArray<2> kernel_size) + : FractionalMaxPool2dImpl(FractionalMaxPool2dOptions(kernel_size)) {} + explicit FractionalMaxPool2dImpl(FractionalMaxPool2dOptions options_); + + void reset() override; + + /// Pretty prints the `FractionalMaxPool2d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + /// Returns the outputs and the indices of the max values. + /// Useful for `torch::nn::MaxUnpool2d` later. + std::tuple forward_with_indices(const Tensor& input); + + /// The options with which this `Module` was constructed. + FractionalMaxPool2dOptions options; + + Tensor _random_samples; +}; + +/// A `ModuleHolder` subclass for `FractionalMaxPool2dImpl`. +/// See the documentation for `FractionalMaxPool2dImpl` class to learn what +/// methods it provides, and examples of how to use `FractionalMaxPool2d` with +/// `torch::nn::FractionalMaxPool2dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(FractionalMaxPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FractionalMaxPool3d +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies fractional maxpool over a 3-D input. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.FractionalMaxPool3d to +/// learn about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::FractionalMaxPool3dOptions` class to +/// learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// FractionalMaxPool3d model(FractionalMaxPool3dOptions(5).output_size(1)); +/// ``` +class TORCH_API FractionalMaxPool3dImpl + : public torch::nn::Cloneable { + public: + FractionalMaxPool3dImpl(ExpandingArray<3> kernel_size) + : FractionalMaxPool3dImpl(FractionalMaxPool3dOptions(kernel_size)) {} + explicit FractionalMaxPool3dImpl(FractionalMaxPool3dOptions options_); + + void reset() override; + + /// Pretty prints the `FractionalMaxPool3d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + /// Returns the outputs and the indices of the max values. + /// Useful for `torch::nn::MaxUnpool3d` later. + std::tuple forward_with_indices(const Tensor& input); + + /// The options with which this `Module` was constructed. + FractionalMaxPool3dOptions options; + + Tensor _random_samples; +}; + +/// A `ModuleHolder` subclass for `FractionalMaxPool3dImpl`. +/// See the documentation for `FractionalMaxPool3dImpl` class to learn what +/// methods it provides, and examples of how to use `FractionalMaxPool3d` with +/// `torch::nn::FractionalMaxPool3dOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(FractionalMaxPool3d); + +// ============================================================================ + +/// Base class for all (dimension-specialized) lppool modules. +template +class TORCH_API LPPoolImpl : public torch::nn::Cloneable { + public: + LPPoolImpl(double norm_type, ExpandingArray kernel_size) + : LPPoolImpl(LPPoolOptions(norm_type, kernel_size)) {} + explicit LPPoolImpl(const LPPoolOptions& options_); + + void reset() override; + + /// Pretty prints the `LPPool{1,2}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + LPPoolOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LPPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the LPPool1d function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.LPPool1d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::LPPool1dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// LPPool1d model(LPPool1dOptions(1, 2).stride(5).ceil_mode(true)); +/// ``` +class TORCH_API LPPool1dImpl : public LPPoolImpl<1, LPPool1dImpl> { + public: + using LPPoolImpl<1, LPPool1dImpl>::LPPoolImpl; + + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `LPPool1dImpl`. +/// See the documentation for `LPPool1dImpl` class to learn what methods it +/// provides, and examples of how to use `LPPool1d` with +/// `torch::nn::LPPool1dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(LPPool1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LPPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the LPPool2d function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.LPPool2d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::LPPool2dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// LPPool2d model(LPPool2dOptions(1, std::vector({3, 4})).stride({5, +/// 6}).ceil_mode(true)); +/// ``` +class TORCH_API LPPool2dImpl : public LPPoolImpl<2, LPPool2dImpl> { + public: + using LPPoolImpl<2, LPPool2dImpl>::LPPoolImpl; + + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `LPPool2dImpl`. +/// See the documentation for `LPPool2dImpl` class to learn what methods it +/// provides, and examples of how to use `LPPool2d` with +/// `torch::nn::LPPool2dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(LPPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LPPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the LPPool3d function element-wise. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.LPPool3d to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::LPPool3dOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// LPPool3d model(LPPool3dOptions(1, std::vector({3, 4, 5})).stride( +/// {5, 6, 7}).ceil_mode(true)); +/// ``` +class TORCH_API LPPool3dImpl : public LPPoolImpl<3, LPPool3dImpl> { + public: + using LPPoolImpl<3, LPPool3dImpl>::LPPoolImpl; + + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `LPPool3dImpl`. +/// See the documentation for `LPPool3dImpl` class to learn what methods it +/// provides, and examples of how to use `LPPool3d` with +/// `torch::nn::LPPool3dOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(LPPool3d); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/rnn.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/rnn.h new file mode 100644 index 0000000000000000000000000000000000000000..4d30ea149ba3fb11e7cf8c248185fa7f5be65952 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/rnn.h @@ -0,0 +1,399 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace torch::nn { + +namespace detail { +/// Base class for all RNN implementations (intended for code sharing). +template +class TORCH_API RNNImplBase : public torch::nn::Cloneable { + public: + explicit RNNImplBase(const RNNOptionsBase& options_); + + /// Initializes the parameters of the RNN module. + void reset() override; + + void reset_parameters(); + + /// Overrides `nn::Module::to()` to call `flatten_parameters()` after the + /// original operation. + void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false) + override; + void to(torch::Dtype dtype, bool non_blocking = false) override; + void to(torch::Device device, bool non_blocking = false) override; + + /// Pretty prints the RNN module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// Modifies the internal storage of weights for optimization purposes. + /// + /// On CPU, this method should be called if any of the weight or bias vectors + /// are changed (i.e. weights are added or removed). On GPU, it should be + /// called __any time the storage of any parameter is modified__, e.g. any + /// time a parameter is assigned a new value. This allows using the fast path + /// in cuDNN implementations of respective RNN `forward()` methods. It is + /// called once upon construction, inside `reset()`. + void flatten_parameters(); + + std::vector all_weights() const; + + /// The RNN's options. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + RNNOptionsBase options_base; + + protected: + // Resets flat_weights_ + // Note: be v. careful before removing this, as 3rd party device types + // likely rely on this behavior to properly .to() modules like LSTM. + void reset_flat_weights(); + + void check_input(const Tensor& input, const Tensor& batch_sizes) const; + + std::tuple get_expected_hidden_size( + const Tensor& input, + const Tensor& batch_sizes) const; + + void check_hidden_size( + const Tensor& hx, + std::tuple expected_hidden_size, + std::string msg = "Expected hidden size {1}, got {2}") const; + + void check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) + const; + + Tensor permute_hidden(Tensor hx, const Tensor& permutation) const; + + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::vector flat_weights_names_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::vector> all_weights_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::vector flat_weights_; +}; +} // namespace detail + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// A multi-layer Elman RNN module with Tanh or ReLU activation. +/// See https://pytorch.org/docs/main/generated/torch.nn.RNN.html to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::RNNOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// RNN model(RNNOptions(128, +/// 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh)); +/// ``` +class TORCH_API RNNImpl : public detail::RNNImplBase { + public: + RNNImpl(int64_t input_size, int64_t hidden_size) + : RNNImpl(RNNOptions(input_size, hidden_size)) {} + explicit RNNImpl(const RNNOptions& options_); + + std::tuple forward(const Tensor& input, Tensor hx = {}); + + protected: + FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}) + + public: + std::tuple + forward_with_packed_input( + const torch::nn::utils::rnn::PackedSequence& packed_input, + Tensor hx = {}); + + RNNOptions options; + + protected: + std::tuple forward_helper( + const Tensor& input, + const Tensor& batch_sizes, + const Tensor& sorted_indices, + int64_t max_batch_size, + Tensor hx); +}; + +/// A `ModuleHolder` subclass for `RNNImpl`. +/// See the documentation for `RNNImpl` class to learn what methods it +/// provides, and examples of how to use `RNN` with `torch::nn::RNNOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(RNN); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// A multi-layer long-short-term-memory (LSTM) module. +/// See https://pytorch.org/docs/main/generated/torch.nn.LSTM.html to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::LSTMOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// LSTM model(LSTMOptions(2, +/// 4).num_layers(3).batch_first(false).bidirectional(true)); +/// ``` +class TORCH_API LSTMImpl : public detail::RNNImplBase { + public: + LSTMImpl(int64_t input_size, int64_t hidden_size) + : LSTMImpl(LSTMOptions(input_size, hidden_size)) {} + explicit LSTMImpl(const LSTMOptions& options_); + + std::tuple> forward( + const Tensor& input, + std::optional> hx_opt = {}); + + protected: + FORWARD_HAS_DEFAULT_ARGS( + {1, AnyValue(std::optional>())}) + + public: + std::tuple> + forward_with_packed_input( + const torch::nn::utils::rnn::PackedSequence& packed_input, + std::optional> hx_opt = {}); + + LSTMOptions options; + + protected: + void check_forward_args( + const Tensor& input, + std::tuple hidden, + const Tensor& batch_sizes) const; + + std::tuple get_expected_cell_size( + const Tensor& input, + const Tensor& batch_sizes) const; + + std::tuple permute_hidden( + std::tuple hx, + const Tensor& permutation) const; + + std::tuple> forward_helper( + const Tensor& input, + const Tensor& batch_sizes, + const Tensor& sorted_indices, + int64_t max_batch_size, + std::optional> hx_opt); +}; + +/// A `ModuleHolder` subclass for `LSTMImpl`. +/// See the documentation for `LSTMImpl` class to learn what methods it +/// provides, and examples of how to use `LSTM` with `torch::nn::LSTMOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(LSTM); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// A multi-layer gated recurrent unit (GRU) module. +/// See https://pytorch.org/docs/main/generated/torch.nn.GRU.html to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::GRUOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// GRU model(GRUOptions(2, +/// 4).num_layers(3).batch_first(false).bidirectional(true)); +/// ``` +class TORCH_API GRUImpl : public detail::RNNImplBase { + public: + GRUImpl(int64_t input_size, int64_t hidden_size) + : GRUImpl(GRUOptions(input_size, hidden_size)) {} + explicit GRUImpl(const GRUOptions& options_); + + std::tuple forward(const Tensor& input, Tensor hx = {}); + + protected: + FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::Tensor())}) + + public: + std::tuple + forward_with_packed_input( + const torch::nn::utils::rnn::PackedSequence& packed_input, + Tensor hx = {}); + + GRUOptions options; + + protected: + std::tuple forward_helper( + const Tensor& input, + const Tensor& batch_sizes, + const Tensor& sorted_indices, + int64_t max_batch_size, + Tensor hx); +}; + +/// A `ModuleHolder` subclass for `GRUImpl`. +/// See the documentation for `GRUImpl` class to learn what methods it +/// provides, and examples of how to use `GRU` with `torch::nn::GRUOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(GRU); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +namespace detail { +/// Base class for all RNNCell implementations (intended for code sharing). +template +class TORCH_API RNNCellImplBase : public torch::nn::Cloneable { + public: + explicit RNNCellImplBase(const RNNCellOptionsBase& options_); + + /// Initializes the parameters of the RNNCell module. + void reset() override; + + void reset_parameters(); + + /// Pretty prints the RNN module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + RNNCellOptionsBase options_base; + + Tensor weight_ih; + Tensor weight_hh; + Tensor bias_ih; + Tensor bias_hh; + + protected: + void check_forward_input(const Tensor& input, const std::string& name) const; + virtual std::string get_nonlinearity_str() const; +}; +} // namespace detail + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCell +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// An Elman RNN cell with tanh or ReLU non-linearity. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.RNNCell to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::RNNCellOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// RNNCell model(RNNCellOptions(20, +/// 10).bias(false).nonlinearity(torch::kReLU)); +/// ``` +class TORCH_API RNNCellImpl : public detail::RNNCellImplBase { + public: + RNNCellImpl(int64_t input_size, int64_t hidden_size) + : RNNCellImpl(RNNCellOptions(input_size, hidden_size)) {} + explicit RNNCellImpl(const RNNCellOptions& options_); + + Tensor forward(const Tensor& input, const Tensor& hx = {}); + + protected: + FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}) + + public: + RNNCellOptions options; + + protected: + std::string get_nonlinearity_str() const override; +}; + +/// A `ModuleHolder` subclass for `RNNCellImpl`. +/// See the documentation for `RNNCellImpl` class to learn what methods it +/// provides, and examples of how to use `RNNCell` with +/// `torch::nn::RNNCellOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(RNNCell); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMCell +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// A long short-term memory (LSTM) cell. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.LSTMCell to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::LSTMCellOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// LSTMCell model(LSTMCellOptions(20, 10).bias(false)); +/// ``` +class TORCH_API LSTMCellImpl : public detail::RNNCellImplBase { + public: + LSTMCellImpl(int64_t input_size, int64_t hidden_size) + : LSTMCellImpl(LSTMCellOptions(input_size, hidden_size)) {} + explicit LSTMCellImpl(const LSTMCellOptions& options_); + + std::tuple forward( + const Tensor& input, + std::optional> hx_opt = {}); + + protected: + FORWARD_HAS_DEFAULT_ARGS( + {1, AnyValue(std::optional>())}) + + public: + LSTMCellOptions options; +}; + +/// A `ModuleHolder` subclass for `LSTMCellImpl`. +/// See the documentation for `LSTMCellImpl` class to learn what methods it +/// provides, and examples of how to use `LSTMCell` with +/// `torch::nn::LSTMCellOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(LSTMCell); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRUCell +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// A gated recurrent unit (GRU) cell. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.GRUCell to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::GRUCellOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// GRUCell model(GRUCellOptions(20, 10).bias(false)); +/// ``` +class TORCH_API GRUCellImpl : public detail::RNNCellImplBase { + public: + GRUCellImpl(int64_t input_size, int64_t hidden_size) + : GRUCellImpl(GRUCellOptions(input_size, hidden_size)) {} + explicit GRUCellImpl(const GRUCellOptions& options_); + + Tensor forward(const Tensor& input, const Tensor& hx = {}); + + protected: + FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}) + + public: + GRUCellOptions options; +}; + +/// A `ModuleHolder` subclass for `GRUCellImpl`. +/// See the documentation for `GRUCellImpl` class to learn what methods it +/// provides, and examples of how to use `GRUCell` with +/// `torch::nn::GRUCellOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(GRUCell); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/transformer.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/transformer.h new file mode 100644 index 0000000000000000000000000000000000000000..2f22f087bf518bad86dd198fcc7d46073f6de435 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/transformer.h @@ -0,0 +1,141 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include + +namespace torch::nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Transformer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// A transformer model. User is able to modify the attributes as needed. The +/// architecture is based on the paper "Attention Is All You Need". Ashish +/// Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N +/// Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. +/// In Advances in Neural Information Processing Systems, pages 6000-6010. +/// +/// See https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html to +/// learn about the exact behavior of this transformer model +/// +/// See the documentation for `torch::nn::Transformer` class to learn what +/// constructor arguments are supported for this encoder layer model +/// +/// Example: +/// ``` +/// Transformer trans(TransformerOptions(512, 8)); +/// ``` +class TORCH_API TransformerImpl : public Cloneable { + public: + explicit TransformerImpl(TransformerOptions options_); + + /// forward function for Transformer Module + /// Args: + /// src: the sequence to the encoder (required). + /// tgt: the sequence to the decoder (required). + /// src_mask: the additive mask for the src sequence (optional). + /// tgt_mask: the additive mask for the tgt sequence (optional). + /// memory_mask: the additive mask for the encoder output (optional). + /// src_key_padding_mask: the ByteTensor mask for src keys per batch + /// (optional). tgt_key_padding_mask: the ByteTensor mask for tgt keys per + /// batch (optional). memory_key_padding_mask: the ByteTensor mask for + /// memory keys per batch (optional). + /// + /// Shape: + /// src: `(S, N, E)` + /// tgt: `(T, N, E)` + /// src_mask: `(S, S)` + /// tgt_mask: `(T, T)` + /// memory_mask: `(T, S)` + /// src_key_padding_mask: `(N, S)` + /// tgt_key_padding_mask: `(N, T)` + /// memory_key_padding_mask: `(N, S)` + /// + /// Note: + /// [src/tgt/memory]_mask ensures that position i is allowed to attend the + /// unmasked positions. If a ByteTensor is provided, the non-zero + /// positions are not allowed to attend while the zero positions will be + /// unchanged. If a BoolTensor is provided, positions with `True` are not + /// allowed to attend while `False` values will be unchanged. If a + /// FloatTensor is provided, it will be added to the attention weight. + /// + /// [src/tgt/memory]_key_padding_mask provides specified elements in the + /// key to be ignored by the attention. If a ByteTensor is provided, the + /// non-zero positions will be ignored while the zero positions will be + /// unchanged. If a BoolTensor is provided, the positions with the value + /// of `True` will be ignored while the position with the value of `False` + /// will be unchanged. + /// + /// output: `(T, N, E)` + /// + /// Note: + /// Due to the multi-head attention architecture in the transformer model, + /// the output sequence length of a transformer is same as the input + /// sequence (i.e. target) length of the decode. + /// + /// where + /// S is the source sequence length, + /// T is the target sequence length, + /// N is the batch size, + /// E is the feature number. + Tensor forward( + const Tensor& src, + const Tensor& tgt, + const Tensor& src_mask = {}, + const Tensor& tgt_mask = {}, + const Tensor& memory_mask = {}, + const Tensor& src_key_padding_mask = {}, + const Tensor& tgt_key_padding_mask = {}, + const Tensor& memory_key_padding_mask = {}); + + void reset() override; + + void reset_parameters(); + + /// Generate a square mask for the sequence. + /// The masked positions are filled with `-inf` in float type. + /// Unmasked positions are filled with `0.0` in float type. + /// Note: + /// 1. This function will always return a CPU tensor. + /// 2. This function requires the platform support IEEE754, since `-inf` is + /// guaranteed to + /// be valid only when IEEE754 is supported. If the platform doesn't + /// support IEEE754, this function will fill the mask with the smallest + /// float number instead of `-inf`, a one time warning will pop up as + /// well. + static Tensor generate_square_subsequent_mask(int64_t sz); + + protected: + FORWARD_HAS_DEFAULT_ARGS( + {2, AnyValue(Tensor())}, + {3, AnyValue(Tensor())}, + {4, AnyValue(Tensor())}, + {5, AnyValue(Tensor())}, + {6, AnyValue(Tensor())}, + {7, AnyValue(Tensor())}) + + public: + /// options with which this `Transformer` was constructed + TransformerOptions options; + + /// encoder module + AnyModule encoder; + + /// decoder module + AnyModule decoder; +}; + +/// A `ModuleHolder` subclass for `TransformerImpl`. +/// See the documentation for `TransformerImpl` class to learn what +/// methods it provides, and examples of how to use `Transformer` with +/// `torch::nn::TransformerOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(Transformer); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/transformercoder.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/transformercoder.h new file mode 100644 index 0000000000000000000000000000000000000000..e06dd81b9234c5378b7dc03a4e20245207fefd26 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/transformercoder.h @@ -0,0 +1,152 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace torch::nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoder +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// TransformerEncoder module. +/// See +/// https://pytorch.org/docs/main/generated/torch.nn.TransformerEncoder.html +/// to learn abouut the exact behavior of this encoder layer module. +/// +/// See the documentation for `torch::nn::TransformerEncoder` class to learn +/// what constructor arguments are supported for this encoder module. +/// +/// Example: +/// ``` +/// TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512, +/// 8).dropout(0.1)); TransformerEncoder +/// encoder(TransformerEncoderOptions(encoderLayer, +/// 6).norm(LayerNorm(LayerNormOptions({2})))); +/// ``` +class TORCH_API TransformerEncoderImpl + : public Cloneable { + public: + TransformerEncoderImpl( + TransformerEncoderLayer encoder_layer, + int64_t num_layers) + : TransformerEncoderImpl( + TransformerEncoderOptions(std::move(encoder_layer), num_layers)) {} + explicit TransformerEncoderImpl(TransformerEncoderOptions options_); + + Tensor forward( + const Tensor& src, + const Tensor& src_mask = {}, + const Tensor& src_key_padding_mask = {}); + + void reset() override; + + void reset_parameters(); + + protected: + FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())}) + + public: + /// options with which this `TransformerEncoder` was constructed + TransformerEncoderOptions options; + + /// module list that contains all the encoder layers + ModuleList layers = nullptr; + + /// optional normalization module + AnyModule norm; +}; + +/// A `ModuleHolder` subclass for `TransformerEncoderImpl`. +/// See the documentation for `TransformerEncoderImpl` class to learn what +/// methods it provides, and examples of how to use `TransformerEncoder` with +/// `torch::nn::TransformerEncoderOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(TransformerEncoder); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerDecoder +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// TransformerDecoder is a stack of N decoder layers. +/// See +/// https://pytorch.org/docs/main/generated/torch.nn.TransformerDecoder.html +/// to learn abouut the exact behavior of this decoder module +/// +/// See the documentation for `torch::nn::TransformerDecoderOptions` class to +/// learn what constructor arguments are supported for this decoder module +/// +/// Example: +/// ``` +/// TransformerDecoderLayer decoder_layer(TransformerDecoderLayerOptions(512, +/// 8).dropout(0.1)); TransformerDecoder +/// transformer_decoder(TransformerDecoderOptions(decoder_layer, +/// 6).norm(LayerNorm(LayerNormOptions({2})))); const auto memory = +/// torch::rand({10, 32, 512}); const auto tgt = torch::rand({20, 32, 512}); +/// auto out = transformer_decoder(tgt, memory); +/// ``` +class TORCH_API TransformerDecoderImpl + : public Cloneable { + public: + TransformerDecoderImpl( + TransformerDecoderLayer decoder_layer, + int64_t num_layers) + : TransformerDecoderImpl( + TransformerDecoderOptions(std::move(decoder_layer), num_layers)) {} + explicit TransformerDecoderImpl(TransformerDecoderOptions options_); + + void reset() override; + + void reset_parameters(); + + /// Pass the inputs (and mask) through the decoder layer in turn. + /// Args: + /// tgt: the sequence to the decoder layer (required). + /// memory: the sequence from the last layer of the encoder (required). + /// tgt_mask: the mask for the tgt sequence (optional). + /// memory_mask: the mask for the memory sequence (optional). + /// tgt_key_padding_mask: the mask for the tgt keys per batch + /// (optional). memory_key_padding_mask: the mask for the memory keys + /// per batch (optional). + Tensor forward( + const Tensor& tgt, + const Tensor& memory, + const Tensor& tgt_mask = {}, + const Tensor& memory_mask = {}, + const Tensor& tgt_key_padding_mask = {}, + const Tensor& memory_key_padding_mask = {}); + + /// The options used to configure this module. + TransformerDecoderOptions options; + + /// Cloned layers of decoder layers + ModuleList layers{nullptr}; + + /// optional layer normalization module + AnyModule norm; + + protected: + FORWARD_HAS_DEFAULT_ARGS( + {2, AnyValue(Tensor())}, + {3, AnyValue(Tensor())}, + {4, AnyValue(Tensor())}, + {5, AnyValue(Tensor())}) +}; + +/// A `ModuleHolder` subclass for `TransformerDecoderImpl`. +/// See the documentation for `TransformerDecoderImpl` class to learn what +/// methods it provides, and examples of how to use `TransformerDecoder` with +/// `torch::nn::TransformerDecoderOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(TransformerDecoder); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/transformerlayer.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/transformerlayer.h new file mode 100644 index 0000000000000000000000000000000000000000..74f1143e5c1637acf2933d3a8e3cdd2bef97efb7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/transformerlayer.h @@ -0,0 +1,193 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace torch::nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoderLayer +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// TransformerEncoderLayer module. +/// See +/// https://pytorch.org/docs/main/generated/torch.nn.TransformerEncoderLayer.html +/// to learn abouut the exact behavior of this encoder layer model +/// +/// See the documentation for `torch::nn::TransformerEncoderLayer` class to +/// learn what constructor arguments are supported for this encoder layer model +/// +/// Example: +/// ``` +/// TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512, +/// 8).dropout(0.1)); +/// ``` +class TORCH_API TransformerEncoderLayerImpl + : public Cloneable { + public: + TransformerEncoderLayerImpl(int64_t d_model, int64_t nhead) + : TransformerEncoderLayerImpl( + TransformerEncoderLayerOptions(d_model, nhead)) {} + explicit TransformerEncoderLayerImpl(TransformerEncoderLayerOptions options_); + + Tensor forward( + const Tensor& src, + const Tensor& src_mask = {}, + const Tensor& src_key_padding_mask = {}); + + void reset() override; + + void reset_parameters(); + + protected: + FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())}) + + public: + /// options with which this `TransformerEncoderLayer` was constructed + TransformerEncoderLayerOptions options; + + /// self attention + MultiheadAttention self_attn = nullptr; + + /// feedforward first linear layer + Linear linear1 = nullptr; + + /// feedforward dropout layer + Dropout dropout = nullptr; + + /// feedforward second linear layer + Linear linear2 = nullptr; + + /// pre feedforward, normalization layer + LayerNorm norm1 = nullptr; + /// post feedfastward, normalization layer + LayerNorm norm2 = nullptr; + + /// pre feedfastward, dropout layer + Dropout dropout1 = nullptr; + /// post feedfastward, dropout layer + Dropout dropout2 = nullptr; +}; + +/// A `ModuleHolder` subclass for `TransformerEncoderLayerImpl``. +/// See the documentation for `TransformerEncoderLayerImpl` class to learn what +/// methods it provides, and examples of how to use `TransformerEncoderLayer` +/// with `torch::nn::TransformerEncoderLayerOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(TransformerEncoderLayer); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerDecoderLayer +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// TransformerDecoderLayer is made up of self-attn, multi-head-attn and +/// feedforward network. This standard decoder layer is based on the paper +/// "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, Niki Parmar, +/// Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia +/// Polosukhin. 2017. Attention is all you need. In Advances in Neural +/// Information Processing Systems, pages 6000-6010. Users may modify or +/// implement in a different way during application. See +/// https://pytorch.org/docs/main/nn.html#transformer-layers to learn about +/// the exact behavior of this module. +/// +/// See the documentation for `torch::nn::TransformerDecoderLayerOptions` class +/// to learn what constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// TransformerDecoderLayer model(TransformerDecoderLayerOptions(512, +/// 8).dropout(0.2)); +/// ``` +class TORCH_API TransformerDecoderLayerImpl + : public Cloneable { + public: + TransformerDecoderLayerImpl(int64_t d_model, int64_t nhead) + : TransformerDecoderLayerImpl( + TransformerDecoderLayerOptions(d_model, nhead)) {} + explicit TransformerDecoderLayerImpl(TransformerDecoderLayerOptions options_); + + void reset() override; + + void reset_parameters(); + + /// Pass the inputs (and mask) through the decoder layer. + /// Args: + /// tgt: the sequence to the decoder layer (required). + /// memory: the sequence from the last layer of the encoder (required). + /// tgt_mask: the mask for the tgt sequence (optional). + /// memory_mask: the mask for the memory sequence (optional). + /// tgt_key_padding_mask: the mask for the tgt keys per batch + /// (optional). memory_key_padding_mask: the mask for the memory keys + /// per batch (optional). + Tensor forward( + Tensor tgt, + const Tensor& memory, + const Tensor& tgt_mask = {}, + const Tensor& memory_mask = {}, + const Tensor& tgt_key_padding_mask = {}, + const Tensor& memory_key_padding_mask = {}); + + /// The options used to configure this module. + TransformerDecoderLayerOptions options; + + /// self attention + MultiheadAttention self_attn{nullptr}; + + /// Dropout, post self attention + Dropout dropout1{nullptr}; + + /// Normalization, post self attention + LayerNorm norm1{nullptr}; + + /// Multi-headed attention + MultiheadAttention multihead_attn{nullptr}; + + /// Dropout, post multi-headed attention + Dropout dropout2{nullptr}; + + /// Normalization, post multi-headed attention + LayerNorm norm2{nullptr}; + + /// Feed forward first linear layer + Linear linear1{nullptr}; + + /// Feed forward dropout layer + Dropout dropout{nullptr}; + + /// Feed forward second linear layer + Linear linear2{nullptr}; + + /// Dropout, post feed forward + Dropout dropout3{nullptr}; + + /// Normalization, post feed forward + LayerNorm norm3{nullptr}; + + protected: + FORWARD_HAS_DEFAULT_ARGS( + {2, AnyValue(Tensor())}, + {3, AnyValue(Tensor())}, + {4, AnyValue(Tensor())}, + {5, AnyValue(Tensor())}) + + /// Apply activation based on configuration + Tensor activation(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `TransformerDecoderLayerImpl`. +/// See the documentation for `TransformerDecoderLayerImpl` class to learn what +/// methods it provides, and examples of how to use `TransformerDecoderLayer` +/// with `torch::nn::TransformerDecoderLayerOptions`. See the documentation for +/// `ModuleHolder` to learn about PyTorch's module storage semantics. +TORCH_MODULE(TransformerDecoderLayer); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/upsampling.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/upsampling.h new file mode 100644 index 0000000000000000000000000000000000000000..e02658a6af4e1b4ef5c8aa35ab09753245813d83 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/upsampling.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace torch::nn { + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Upsample ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D +/// (volumetric) data. +/// See https://pytorch.org/docs/main/nn.html#torch.nn.Upsample to learn +/// about the exact behavior of this module. +/// +/// See the documentation for `torch::nn::UpsampleOptions` class to learn what +/// constructor arguments are supported for this module. +/// +/// Example: +/// ``` +/// Upsample +/// model(UpsampleOptions().scale_factor({3}).mode(torch::kLinear).align_corners(false)); +/// ``` +class TORCH_API UpsampleImpl : public Cloneable { + public: + explicit UpsampleImpl(UpsampleOptions options_ = {}); + + void reset() override; + + /// Pretty prints the `Upsample` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input); + + /// The options with which this `Module` was constructed. + UpsampleOptions options; +}; + +/// A `ModuleHolder` subclass for `UpsampleImpl`. +/// See the documentation for `UpsampleImpl` class to learn what methods it +/// provides, and examples of how to use `Upsample` with +/// `torch::nn::UpsampleOptions`. See the documentation for `ModuleHolder` to +/// learn about PyTorch's module storage semantics. +TORCH_MODULE(Upsample); + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/utils.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..b89abb748c89870320c9a6b6ccc570dd2b3927fa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/modules/utils.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include +#include + +#include + +namespace torch::nn::modules::utils { + +// Reverse the order of `t` and repeat each element for `n` times. +// This can be used to translate padding arg used by Conv and Pooling modules +// to the ones used by `F::pad`. +// +// This mirrors `_reverse_repeat_tuple` in `torch/nn/modules/utils.py`. +inline std::vector _reverse_repeat_vector( + c10::ArrayRef t, + int64_t n) { + TORCH_INTERNAL_ASSERT(n >= 0); + std::vector ret; + ret.reserve(t.size() * n); + for (auto rit = t.rbegin(); rit != t.rend(); ++rit) { + for ([[maybe_unused]] const auto i : c10::irange(n)) { + ret.emplace_back(*rit); + } + } + return ret; +} + +inline std::vector _list_with_default( + c10::ArrayRef> out_size, + c10::IntArrayRef defaults) { + TORCH_CHECK( + defaults.size() > out_size.size(), + "Input dimension should be at least ", + out_size.size() + 1); + std::vector ret; + c10::IntArrayRef defaults_slice = + defaults.slice(defaults.size() - out_size.size(), out_size.size()); + for (const auto i : c10::irange(out_size.size())) { + auto v = out_size.at(i); + auto d = defaults_slice.at(i); + ret.emplace_back(v.has_value() ? v.value() : d); + } + return ret; +} + +} // namespace torch::nn::modules::utils diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options.h new file mode 100644 index 0000000000000000000000000000000000000000..4a5224a478e1b650e393dcb3f95adc13ab36d65f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/activation.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/activation.h new file mode 100644 index 0000000000000000000000000000000000000000..480e09ad4de2b504407522accc47b3df5b745049 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/activation.h @@ -0,0 +1,712 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn { + +/// Options for the `ELU` module. +/// +/// Example: +/// ``` +/// ELU model(ELUOptions().alpha(42.42).inplace(true)); +/// ``` +struct TORCH_API ELUOptions { + /// The `alpha` value for the ELU formulation. Default: 1.0 + TORCH_ARG(double, alpha) = 1.0; + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace) = false; +}; + +namespace functional { +/// Options for `torch::nn::functional::elu`. +/// +/// See the documentation for `torch::nn::ELUOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::elu(x, F::ELUFuncOptions().alpha(0.42).inplace(true)); +/// ``` +using ELUFuncOptions = ELUOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `SELU` module. +/// +/// Example: +/// ``` +/// SELU model(SELUOptions().inplace(true)); +/// ``` +struct TORCH_API SELUOptions { + /* implicit */ SELUOptions(bool inplace = false); + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace); +}; + +namespace functional { +/// Options for `torch::nn::functional::selu`. +/// +/// See the documentation for `torch::nn::SELUOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::selu(input, F::SELUFuncOptions(false)); +/// ``` +using SELUFuncOptions = SELUOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `GLU` module. +/// +/// Example: +/// ``` +/// GLU model(GLUOptions(1)); +/// ``` +struct TORCH_API GLUOptions { + /* implicit */ GLUOptions(int64_t dim = -1); + + /// the dimension on which to split the input. Default: -1 + TORCH_ARG(int64_t, dim); +}; + +namespace functional { +/// Options for `torch::nn::functional::glu`. +/// +/// See the documentation for `torch::nn::GLUOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::glu(input, GLUFuncOptions(1)); +/// ``` +using GLUFuncOptions = GLUOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `GELU` module. +/// +/// Example: +/// ``` +/// GELU model(GELUOptions().approximate("none")); +/// ``` +struct TORCH_API GELUOptions { + /// Specifies the approximation to apply to the output. + TORCH_ARG(std::string, approximate) = "none"; +}; + +namespace functional { +/// Options for `torch::nn::functional::gelu`. +/// +/// See the documentation for `torch::nn::GELUOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::gelu(input, F::GELUFuncOptions().approximate("none")); +/// ``` +using GELUFuncOptions = GELUOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `Hardshrink` module. +/// +/// Example: +/// ``` +/// Hardshrink model(HardshrinkOptions().lambda(42.42)); +/// ``` +struct TORCH_API HardshrinkOptions { + /* implicit */ HardshrinkOptions(double lambda = 0.5); + + /// the `lambda` value for the Hardshrink formulation. Default: 0.5 + TORCH_ARG(double, lambda); +}; + +namespace functional { +/// Options for `torch::nn::functional::hardshrink`. +/// +/// See the documentation for `torch::nn::HardshrinkOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::hardshrink(x, F::HardshrinkFuncOptions().lambda(0.42)); +/// ``` +using HardshrinkFuncOptions = HardshrinkOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `Hardtanh` module. +/// +/// Example: +/// ``` +/// Hardtanh +/// model(HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true)); +/// ``` +struct TORCH_API HardtanhOptions { + /// minimum value of the linear region range. Default: -1 + TORCH_ARG(double, min_val) = -1.0; + + /// maximum value of the linear region range. Default: 1 + TORCH_ARG(double, max_val) = 1.0; + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace) = false; +}; + +namespace functional { +/// Options for `torch::nn::functional::hardtanh`. +/// +/// See the documentation for `torch::nn::HardtanhOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::hardtanh(x, +/// F::HardtanhFuncOptions().min_val(-1.0).max_val(1.0).inplace(true)); +/// ``` +using HardtanhFuncOptions = HardtanhOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `LeakyReLU` module. +/// +/// Example: +/// ``` +/// LeakyReLU model(LeakyReLUOptions().negative_slope(0.42).inplace(true)); +/// ``` +struct TORCH_API LeakyReLUOptions { + /// Controls the angle of the negative slope. Default: 1e-2 + TORCH_ARG(double, negative_slope) = 1e-2; + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace) = false; +}; + +namespace functional { +/// Options for `torch::nn::functional::leaky_relu`. +/// +/// See the documentation for `torch::nn::LeakyReLUOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::leaky_relu(x, +/// F::LeakyReLUFuncOptions().negative_slope(0.42).inplace(true)); +/// ``` +using LeakyReLUFuncOptions = LeakyReLUOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `Softmax` module. +/// +/// Example: +/// ``` +/// Softmax model(SoftmaxOptions(1)); +/// ``` +struct TORCH_API SoftmaxOptions { + SoftmaxOptions(int64_t dim); + + /// Dimension along which Softmax will be computed. + TORCH_ARG(int64_t, dim); +}; + +// ============================================================================ + +namespace functional { + +/// Options for `torch::nn::functional::softmax`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::softmax(input, F::SoftmaxFuncOptions(1)); +/// ``` +struct TORCH_API SoftmaxFuncOptions { + SoftmaxFuncOptions(int64_t dim); + + /// Dimension along which Softmax will be computed. + TORCH_ARG(int64_t, dim); + + /// the desired data type of returned tensor. + /// If specified, the input tensor is casted to `dtype` before the operation + /// is performed. This is useful for preventing data type overflows. Default: + /// None. + TORCH_ARG(std::optional, dtype) = std::nullopt; +}; + +} // namespace functional + +// ============================================================================ + +/// Options for the `Softmin` module. +/// +/// Example: +/// ``` +/// Softmin model(SoftminOptions(1)); +/// ``` +struct TORCH_API SoftminOptions { + SoftminOptions(int64_t dim); + + /// Dimension along which Softmin will be computed. + TORCH_ARG(int64_t, dim); +}; + +// ============================================================================ + +namespace functional { + +/// Options for `torch::nn::functional::softmin`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::softmin(input, F::SoftminFuncOptions(1)); +/// ``` +struct TORCH_API SoftminFuncOptions { + SoftminFuncOptions(int64_t dim); + + /// Dimension along which Softmin will be computed. + TORCH_ARG(int64_t, dim); + + /// the desired data type of returned tensor. + /// If specified, the input tensor is casted to `dtype` before the operation + /// is performed. This is useful for preventing data type overflows. Default: + /// None. + TORCH_ARG(std::optional, dtype) = std::nullopt; +}; + +} // namespace functional + +// ============================================================================ + +/// Options for the `LogSoftmax` module. +/// +/// Example: +/// ``` +/// LogSoftmax model(LogSoftmaxOptions(1)); +/// ``` +struct TORCH_API LogSoftmaxOptions { + LogSoftmaxOptions(int64_t dim); + + /// Dimension along which LogSoftmax will be computed. + TORCH_ARG(int64_t, dim); +}; + +// ============================================================================ + +namespace functional { + +/// Options for `torch::nn::functional::log_softmax`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::log_softmax(input, LogSoftmaxFuncOptions(1)); +/// ``` +struct TORCH_API LogSoftmaxFuncOptions { + LogSoftmaxFuncOptions(int64_t dim); + + /// Dimension along which LogSoftmax will be computed. + TORCH_ARG(int64_t, dim); + + /// the desired data type of returned tensor. + /// If specified, the input tensor is casted to `dtype` before the operation + /// is performed. This is useful for preventing data type overflows. Default: + /// None. + TORCH_ARG(std::optional, dtype) = std::nullopt; +}; + +} // namespace functional + +// ============================================================================ + +/// Options for the `PReLU` module. +/// +/// Example: +/// ``` +/// PReLU model(PReLUOptions().num_parameters(42)); +/// ``` +struct TORCH_API PReLUOptions { + /// number of `a` to learn. Although it takes an int as input, there is only + /// two values are legitimate: 1, or the number of channels at input. Default: + /// 1 + TORCH_ARG(int64_t, num_parameters) = 1; + + /// the initial value of `a`. Default: 0.25 + TORCH_ARG(double, init) = 0.25; +}; + +// ============================================================================ + +/// Options for the `ReLU` module. +/// +/// Example: +/// ``` +/// ReLU model(ReLUOptions().inplace(true)); +/// ``` +struct TORCH_API ReLUOptions { + /* implicit */ ReLUOptions(bool inplace = false); + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace); +}; + +namespace functional { +/// Options for `torch::nn::functional::relu`. +/// +/// See the documentation for `torch::nn::ReLUOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::relu(x, F::ReLUFuncOptions().inplace(true)); +/// ``` +using ReLUFuncOptions = ReLUOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `ReLU6` module. +/// +/// Example: +/// ``` +/// ReLU6 model(ReLU6Options().inplace(true)); +/// ``` +struct TORCH_API ReLU6Options { + /* implicit */ ReLU6Options(bool inplace = false); + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace); +}; + +namespace functional { +/// Options for `torch::nn::functional::relu6`. +/// +/// See the documentation for `torch::nn::ReLU6Options` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::relu6(x, F::ReLU6FuncOptions().inplace(true)); +/// ``` +using ReLU6FuncOptions = ReLU6Options; +} // namespace functional + +// ============================================================================ + +/// Options for the `RReLU` module. +/// +/// Example: +/// ``` +/// RReLU model(RReLUOptions().lower(0.24).upper(0.42).inplace(true)); +/// ``` +struct TORCH_API RReLUOptions { + /// lower bound of the uniform distribution. Default: 1/8 + TORCH_ARG(double, lower) = 1.0 / 8.0; + + /// upper bound of the uniform distribution. Default: 1/3 + TORCH_ARG(double, upper) = 1.0 / 3.0; + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace) = false; +}; + +// ============================================================================ + +namespace functional { + +/// Options for `torch::nn::functional::rrelu`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::rrelu(x, F::RReLUFuncOptions().lower(0.1).upper(0.4).inplace(true)); +/// ``` +struct TORCH_API RReLUFuncOptions { + /// lower bound of the uniform distribution. Default: 1/8 + TORCH_ARG(double, lower) = 1.0 / 8.0; + + /// upper bound of the uniform distribution. Default: 1/3 + TORCH_ARG(double, upper) = 1.0 / 3.0; + + TORCH_ARG(bool, training) = false; + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace) = false; +}; + +} // namespace functional + +// ============================================================================ + +/// Options for the `CELU` module. +/// +/// Example: +/// ``` +/// CELU model(CELUOptions().alpha(42.42).inplace(true)); +/// ``` +struct TORCH_API CELUOptions { + /// The `alpha` value for the CELU formulation. Default: 1.0 + TORCH_ARG(double, alpha) = 1.0; + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace) = false; +}; + +namespace functional { +/// Options for `torch::nn::functional::celu`. +/// +/// See the documentation for `torch::nn::CELUOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::celu(x, F::CELUFuncOptions().alpha(0.42).inplace(true)); +/// ``` +using CELUFuncOptions = CELUOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `Softplus` module. +/// +/// Example: +/// ``` +/// Softplus model(SoftplusOptions().beta(0.24).threshold(42.42)); +/// ``` +struct TORCH_API SoftplusOptions { + /// the `beta` value for the Softplus formulation. Default: 1 + TORCH_ARG(double, beta) = 1.0; + + /// values above this revert to a linear function. Default: 20 + TORCH_ARG(double, threshold) = 20.0; +}; + +namespace functional { +/// Options for `torch::nn::functional::softplus`. +/// +/// See the documentation for `torch::nn::SoftplusOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::softplus(x, F::SoftplusFuncOptions().beta(0.5).threshold(3.0)); +/// ``` +using SoftplusFuncOptions = SoftplusOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `Softshrink` module. +/// +/// Example: +/// ``` +/// Softshrink model(SoftshrinkOptions(42.42)); +/// ``` +struct TORCH_API SoftshrinkOptions { + /* implicit */ SoftshrinkOptions(double lambda = 0.5); + + /// the `lambda` value for the Softshrink formulation. Default: 0.5 + TORCH_ARG(double, lambda); +}; + +namespace functional { +/// Options for `torch::nn::functional::softshrink`. +/// +/// See the documentation for `torch::nn::SoftshrinkOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::softshrink(x, F::SoftshrinkFuncOptions(0.42)); +/// ``` +using SoftshrinkFuncOptions = SoftshrinkOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `Threshold` module. +/// +/// Example: +/// ``` +/// Threshold model(ThresholdOptions(42.42, 24.24).inplace(true)); +/// ``` +struct TORCH_API ThresholdOptions { + ThresholdOptions(double threshold, double value) + : threshold_(threshold), value_(value) {} + + /// The value to threshold at + TORCH_ARG(double, threshold); + + /// The value to replace with + TORCH_ARG(double, value); + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace) = false; +}; + +namespace functional { +/// Options for `torch::nn::functional::threshold`. +/// +/// See the documentation for `torch::nn::ThresholdOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::threshold(x, F::ThresholdFuncOptions(0.5, 0.5).inplace(true)); +/// ``` +using ThresholdFuncOptions = ThresholdOptions; +} // namespace functional + +// ============================================================================ + +namespace functional { + +/// Options for `torch::nn::functional::gumbel_softmax`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(-1)); +/// ``` +struct TORCH_API GumbelSoftmaxFuncOptions { + /// non-negative scalar temperature + TORCH_ARG(double, tau) = 1.0; + + /// returned samples will be discretized as one-hot vectors, + /// but will be differentiated as if it is the soft sample in autograd. + /// Default: False + TORCH_ARG(bool, hard) = false; + + /// dimension along which softmax will be computed. Default: -1 + TORCH_ARG(int, dim) = -1; +}; + +} // namespace functional + +// ============================================================================ + +/// Options for the `MultiheadAttention` module. +/// +/// Example: +/// ``` +/// MultiheadAttention model(MultiheadAttentionOptions(20, 10).bias(false)); +/// ``` +struct TORCH_API MultiheadAttentionOptions { + MultiheadAttentionOptions(int64_t embed_dim, int64_t num_heads); + + /// total dimension of the model. + TORCH_ARG(int64_t, embed_dim); + + /// parallel attention heads. + TORCH_ARG(int64_t, num_heads); + + /// a Dropout layer on attn_output_weights. Default: 0.0. + TORCH_ARG(double, dropout) = 0.0; + + /// add bias as module parameter. Default: true. + TORCH_ARG(bool, bias) = true; + + /// add bias to the key and value sequences at dim=0. + TORCH_ARG(bool, add_bias_kv) = false; + + /// add a new batch of zeros to the key and value sequences at dim=1. + TORCH_ARG(bool, add_zero_attn) = false; + + /// total number of features in key. Default: std::nullopt. + TORCH_ARG(int64_t, kdim); + + /// total number of features in key. Default: std::nullopt. + TORCH_ARG(int64_t, vdim); +}; + +// ============================================================================ + +namespace functional { + +/// Options for `torch::nn::functional::multi_head_attention_forward` +struct TORCH_API MultiheadAttentionForwardFuncOptions { + MultiheadAttentionForwardFuncOptions( + int64_t embed_dim_to_check, + int64_t num_heads, + Tensor in_proj_weight, + Tensor in_proj_bias, + Tensor bias_k, + Tensor bias_v, + bool add_zero_attn, + double dropout_p, + Tensor out_proj_weight, + Tensor out_proj_bias); + + TORCH_ARG(int64_t, embed_dim_to_check); + + TORCH_ARG(int64_t, num_heads); + + TORCH_ARG(Tensor, in_proj_weight); + + TORCH_ARG(Tensor, in_proj_bias); + + TORCH_ARG(Tensor, bias_k); + + TORCH_ARG(Tensor, bias_v); + + TORCH_ARG(bool, add_zero_attn); + + TORCH_ARG(double, dropout_p); + + TORCH_ARG(Tensor, out_proj_weight); + + TORCH_ARG(Tensor, out_proj_bias); + + TORCH_ARG(bool, training) = true; + + TORCH_ARG(Tensor, key_padding_mask) = {}; + + TORCH_ARG(bool, need_weights) = true; + + TORCH_ARG(Tensor, attn_mask) = {}; + + TORCH_ARG(bool, use_separate_proj_weight) = false; + + TORCH_ARG(Tensor, q_proj_weight) = {}; + + TORCH_ARG(Tensor, k_proj_weight) = {}; + + TORCH_ARG(Tensor, v_proj_weight) = {}; + + TORCH_ARG(Tensor, static_k) = {}; + + TORCH_ARG(Tensor, static_v) = {}; + + TORCH_ARG(bool, average_attn_weights) = true; +}; + +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/adaptive.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/adaptive.h new file mode 100644 index 0000000000000000000000000000000000000000..4335fb725c6f403ba3e95b583c692e503df069d5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/adaptive.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include + +namespace torch::nn { + +/// Options for the `AdaptiveLogSoftmaxWithLoss` module. +/// +/// Example: +/// ``` +/// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10, +/// {4, 8}).div_value(2.).head_bias(true)); +/// ``` +struct TORCH_API AdaptiveLogSoftmaxWithLossOptions { + /* implicit */ AdaptiveLogSoftmaxWithLossOptions( + int64_t in_features, + int64_t n_classes, + std::vector cutoffs); + + /// Number of features in the input tensor + TORCH_ARG(int64_t, in_features); + + /// Number of classes in the dataset + TORCH_ARG(int64_t, n_classes); + + /// Cutoffs used to assign targets to their buckets + TORCH_ARG(std::vector, cutoffs); + + /// value used as an exponent to compute sizes of the clusters. Default: 4.0 + TORCH_ARG(double, div_value) = 4.; + + /// If ``true``, adds a bias term to the 'head' of + /// the adaptive softmax. Default: false + TORCH_ARG(bool, head_bias) = false; +}; + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/batchnorm.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/batchnorm.h new file mode 100644 index 0000000000000000000000000000000000000000..d77cfb4f0d1529ab3208dd75d55896a01b902767 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/batchnorm.h @@ -0,0 +1,93 @@ +#pragma once + +#include +#include +#include + +namespace torch::nn { + +/// Options for the `BatchNorm` module. +struct TORCH_API BatchNormOptions { + /* implicit */ BatchNormOptions(int64_t num_features); + + /// The number of features of the input tensor. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(int64_t, num_features); + + /// The epsilon value added for numerical stability. + /// Changing this parameter after construction __is effective__. + TORCH_ARG(double, eps) = 1e-5; + + /// A momentum multiplier for the mean and variance. + /// Changing this parameter after construction __is effective__. + TORCH_ARG(std::optional, momentum) = 0.1; + + /// Whether to learn a scale and bias that are applied in an affine + /// transformation on the input. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(bool, affine) = true; + + /// Whether to store and update batch statistics (mean and variance) in the + /// module. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(bool, track_running_stats) = true; +}; + +/// Options for the `BatchNorm1d` module. +/// +/// Example: +/// ``` +/// BatchNorm1d +/// model(BatchNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +using BatchNorm1dOptions = BatchNormOptions; + +/// Options for the `BatchNorm2d` module. +/// +/// Example: +/// ``` +/// BatchNorm2d +/// model(BatchNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +using BatchNorm2dOptions = BatchNormOptions; + +/// Options for the `BatchNorm3d` module. +/// +/// Example: +/// ``` +/// BatchNorm3d +/// model(BatchNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +using BatchNorm3dOptions = BatchNormOptions; + +// ============================================================================ + +namespace functional { + +/// Options for `torch::nn::functional::batch_norm`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::batch_norm(input, mean, variance, +/// F::BatchNormFuncOptions().weight(weight).bias(bias).momentum(0.1).eps(1e-05).training(false)); +/// ``` +struct TORCH_API BatchNormFuncOptions { + TORCH_ARG(Tensor, weight) = Tensor(); + + TORCH_ARG(Tensor, bias) = Tensor(); + + TORCH_ARG(bool, training) = false; + + /// A momentum multiplier for the mean and variance. + /// Changing this parameter after construction __is effective__. + TORCH_ARG(double, momentum) = 0.1; + + /// The epsilon value added for numerical stability. + /// Changing this parameter after construction __is effective__. + TORCH_ARG(double, eps) = 1e-5; +}; + +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/conv.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/conv.h new file mode 100644 index 0000000000000000000000000000000000000000..f10d5e9a3106182efea8888b4bc6609854ae8705 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/conv.h @@ -0,0 +1,413 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch::nn { + +namespace detail { + +typedef std::variant< + enumtype::kZeros, + enumtype::kReflect, + enumtype::kReplicate, + enumtype::kCircular> + conv_padding_mode_t; + +template +using conv_padding_t = + std::variant, enumtype::kValid, enumtype::kSame>; + +/// Options for a `D`-dimensional convolution or convolution transpose module. +template +struct ConvNdOptions { + using padding_t = conv_padding_t; + ConvNdOptions( + int64_t in_channels, + int64_t out_channels, + ExpandingArray kernel_size) + : in_channels_(in_channels), + out_channels_(out_channels), + kernel_size_(std::move(kernel_size)) {} + + /// The number of channels the input volumes will have. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(int64_t, in_channels); + + /// The number of output channels the convolution should produce. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(int64_t, out_channels); + + /// The kernel size to use. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, kernel_size); + + /// The stride of the convolution. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, stride) = 1; + + /// The padding to add to the input volumes. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(padding_t, padding) = 0; + + public: + decltype(auto) padding(std::initializer_list il) { + return padding(IntArrayRef{il}); + } + + /// The kernel dilation. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, dilation) = 1; + + /// If true, convolutions will be transpose convolutions (a.k.a. + /// deconvolutions). + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(bool, transposed) = false; + + /// For transpose convolutions, the padding to add to output volumes. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, output_padding) = 0; + + /// The number of convolution groups. + /// This parameter __can__ be changed after construction. + TORCH_ARG(int64_t, groups) = 1; + + /// Whether to add a bias after individual applications of the kernel. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(bool, bias) = true; + + /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or + /// `torch::kCircular`. Default: `torch::kZeros` + TORCH_ARG(conv_padding_mode_t, padding_mode) = torch::kZeros; +}; + +} // namespace detail + +// ============================================================================ + +/// Options for a `D`-dimensional convolution module. +template +struct ConvOptions { + using padding_mode_t = detail::conv_padding_mode_t; + using padding_t = detail::conv_padding_t; + + ConvOptions( + int64_t in_channels, + int64_t out_channels, + ExpandingArray kernel_size) + : in_channels_(in_channels), + out_channels_(out_channels), + kernel_size_(std::move(kernel_size)) {} + + /// The number of channels the input volumes will have. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(int64_t, in_channels); + + /// The number of output channels the convolution should produce. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(int64_t, out_channels); + + /// The kernel size to use. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, kernel_size); + + /// The stride of the convolution. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, stride) = 1; + + /// The padding to add to the input volumes. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(padding_t, padding) = 0; + + public: + decltype(auto) padding(std::initializer_list il) { + return padding(IntArrayRef{il}); + } + + /// The kernel dilation. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, dilation) = 1; + + /// The number of convolution groups. + /// This parameter __can__ be changed after construction. + TORCH_ARG(int64_t, groups) = 1; + + /// Whether to add a bias after individual applications of the kernel. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(bool, bias) = true; + + /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or + /// `torch::kCircular`. Default: `torch::kZeros` + TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros; +}; + +/// `ConvOptions` specialized for the `Conv1d` module. +/// +/// Example: +/// ``` +/// Conv1d model(Conv1dOptions(3, 2, 3).stride(1).bias(false)); +/// ``` +using Conv1dOptions = ConvOptions<1>; + +/// `ConvOptions` specialized for the `Conv2d` module. +/// +/// Example: +/// ``` +/// Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false)); +/// ``` +using Conv2dOptions = ConvOptions<2>; + +/// `ConvOptions` specialized for the `Conv3d` module. +/// +/// Example: +/// ``` +/// Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false)); +/// ``` +using Conv3dOptions = ConvOptions<3>; + +// ============================================================================ + +namespace functional { + +/// Options for a `D`-dimensional convolution functional. +template +struct ConvFuncOptions { + using padding_t = torch::nn::detail::conv_padding_t; + + /// optional bias of shape `(out_channels)`. Default: ``None`` + TORCH_ARG(torch::Tensor, bias) = Tensor(); + + /// The stride of the convolving kernel. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + TORCH_ARG(ExpandingArray, stride) = 1; + + /// Implicit paddings on both sides of the input. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + TORCH_ARG(padding_t, padding) = 0; + + public: + decltype(auto) padding(std::initializer_list il) { + return padding(IntArrayRef{il}); + } + + /// The spacing between kernel elements. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + TORCH_ARG(ExpandingArray, dilation) = 1; + + /// Split input into groups, `in_channels` should be divisible by + /// the number of groups. + TORCH_ARG(int64_t, groups) = 1; +}; + +/// `ConvFuncOptions` specialized for `torch::nn::functional::conv1d`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::conv1d(x, weight, F::Conv1dFuncOptions().stride(1)); +/// ``` +using Conv1dFuncOptions = ConvFuncOptions<1>; + +/// `ConvFuncOptions` specialized for `torch::nn::functional::conv2d`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1)); +/// ``` +using Conv2dFuncOptions = ConvFuncOptions<2>; + +/// `ConvFuncOptions` specialized for `torch::nn::functional::conv3d`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::conv3d(x, weight, F::Conv3dFuncOptions().stride(1)); +/// ``` +using Conv3dFuncOptions = ConvFuncOptions<3>; + +} // namespace functional + +// ============================================================================ + +template +struct ConvTransposeOptions { + using padding_mode_t = detail::conv_padding_mode_t; + + ConvTransposeOptions( + int64_t in_channels, + int64_t out_channels, + ExpandingArray kernel_size) + : in_channels_(in_channels), + out_channels_(out_channels), + kernel_size_(std::move(kernel_size)) {} + + /// The number of channels the input volumes will have. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(int64_t, in_channels); + + /// The number of output channels the convolution should produce. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(int64_t, out_channels); + + /// The kernel size to use. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, kernel_size); + + /// The stride of the convolution. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, stride) = 1; + + /// The padding to add to the input volumes. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, padding) = 0; + + /// For transpose convolutions, the padding to add to output volumes. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, output_padding) = 0; + + /// The number of convolution groups. + /// This parameter __can__ be changed after construction. + TORCH_ARG(int64_t, groups) = 1; + + /// Whether to add a bias after individual applications of the kernel. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(bool, bias) = true; + + /// The kernel dilation. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, dilation) = 1; + + /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or + /// `torch::kCircular`. Default: `torch::kZeros` + TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros; +}; + +/// `ConvTransposeOptions` specialized for the `ConvTranspose1d` module. +/// +/// Example: +/// ``` +/// ConvTranspose1d model(ConvTranspose1dOptions(3, 2, +/// 3).stride(1).bias(false)); +/// ``` +using ConvTranspose1dOptions = ConvTransposeOptions<1>; + +/// `ConvTransposeOptions` specialized for the `ConvTranspose2d` module. +/// +/// Example: +/// ``` +/// ConvTranspose2d model(ConvTranspose2dOptions(3, 2, +/// 3).stride(1).bias(false)); +/// ``` +using ConvTranspose2dOptions = ConvTransposeOptions<2>; + +/// `ConvTransposeOptions` specialized for the `ConvTranspose3d` module. +/// +/// Example: +/// ``` +/// ConvTranspose3d model(ConvTranspose3dOptions(2, 2, +/// 2).stride(1).bias(false)); +/// ``` +using ConvTranspose3dOptions = ConvTransposeOptions<3>; + +// ============================================================================ + +namespace functional { + +/// Options for a `D`-dimensional convolution functional. +template +struct ConvTransposeFuncOptions { + /// optional bias of shape `(out_channels)`. Default: ``None`` + TORCH_ARG(torch::Tensor, bias) = Tensor(); + + /// The stride of the convolving kernel. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + TORCH_ARG(ExpandingArray, stride) = 1; + + /// Implicit paddings on both sides of the input. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + TORCH_ARG(ExpandingArray, padding) = 0; + + /// Additional size added to one side of each dimension in the output shape. + /// Default: 0 + TORCH_ARG(ExpandingArray, output_padding) = 0; + + /// Split input into groups, `in_channels` should be divisible by + /// the number of groups. + TORCH_ARG(int64_t, groups) = 1; + + /// The spacing between kernel elements. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + TORCH_ARG(ExpandingArray, dilation) = 1; +}; + +/// `ConvTransposeFuncOptions` specialized for +/// `torch::nn::functional::conv_transpose1d`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::conv_transpose1d(x, weight, F::ConvTranspose1dFuncOptions().stride(1)); +/// ``` +using ConvTranspose1dFuncOptions = ConvTransposeFuncOptions<1>; + +/// `ConvTransposeFuncOptions` specialized for +/// `torch::nn::functional::conv_transpose2d`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1)); +/// ``` +using ConvTranspose2dFuncOptions = ConvTransposeFuncOptions<2>; + +/// `ConvTransposeFuncOptions` specialized for +/// `torch::nn::functional::conv_transpose3d`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::conv_transpose3d(x, weight, F::ConvTranspose3dFuncOptions().stride(1)); +/// ``` +using ConvTranspose3dFuncOptions = ConvTransposeFuncOptions<3>; + +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/distance.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/distance.h new file mode 100644 index 0000000000000000000000000000000000000000..c9cfc2e0aae2f4c14259c7ae7e9a36943b19d7c7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/distance.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include +#include + +namespace torch::nn { + +/// Options for the `CosineSimilarity` module. +/// +/// Example: +/// ``` +/// CosineSimilarity model(CosineSimilarityOptions().dim(0).eps(0.5)); +/// ``` +struct TORCH_API CosineSimilarityOptions { + /// Dimension where cosine similarity is computed. Default: 1 + TORCH_ARG(int64_t, dim) = 1; + /// Small value to avoid division by zero. Default: 1e-8 + TORCH_ARG(double, eps) = 1e-8; +}; + +namespace functional { +/// Options for `torch::nn::functional::cosine_similarity`. +/// +/// See the documentation for `torch::nn::CosineSimilarityOptions` class to +/// learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::cosine_similarity(input1, input2, +/// F::CosineSimilarityFuncOptions().dim(1)); +/// ``` +using CosineSimilarityFuncOptions = CosineSimilarityOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `PairwiseDistance` module. +/// +/// Example: +/// ``` +/// PairwiseDistance +/// model(PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true)); +/// ``` +struct TORCH_API PairwiseDistanceOptions { + /// The norm degree. Default: 2 + TORCH_ARG(double, p) = 2.0; + /// Small value to avoid division by zero. Default: 1e-6 + TORCH_ARG(double, eps) = 1e-6; + /// Determines whether or not to keep the vector dimension. Default: false + TORCH_ARG(bool, keepdim) = false; +}; + +namespace functional { +/// Options for `torch::nn::functional::pairwise_distance`. +/// +/// See the documentation for `torch::nn::PairwiseDistanceOptions` class to +/// learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::pairwise_distance(input1, input2, F::PairwiseDistanceFuncOptions().p(1)); +/// ``` +using PairwiseDistanceFuncOptions = PairwiseDistanceOptions; +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/dropout.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/dropout.h new file mode 100644 index 0000000000000000000000000000000000000000..865920c599cc3ae32dd55c636507e94eeb4925bb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/dropout.h @@ -0,0 +1,128 @@ +#pragma once + +#include +#include +#include + +namespace torch::nn { + +/// Options for the `Dropout` module. +/// +/// Example: +/// ``` +/// Dropout model(DropoutOptions().p(0.42).inplace(true)); +/// ``` +struct TORCH_API DropoutOptions { + /* implicit */ DropoutOptions(double p = 0.5); + + /// The probability of an element to be zeroed. Default: 0.5 + TORCH_ARG(double, p) = 0.5; + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace) = false; +}; + +/// Options for the `Dropout2d` module. +/// +/// Example: +/// ``` +/// Dropout2d model(Dropout2dOptions().p(0.42).inplace(true)); +/// ``` +using Dropout2dOptions = DropoutOptions; + +/// Options for the `Dropout3d` module. +/// +/// Example: +/// ``` +/// Dropout3d model(Dropout3dOptions().p(0.42).inplace(true)); +/// ``` +using Dropout3dOptions = DropoutOptions; + +/// Options for the `AlphaDropout` module. +/// +/// Example: +/// ``` +/// AlphaDropout model(AlphaDropoutOptions(0.2).inplace(true)); +/// ``` +using AlphaDropoutOptions = DropoutOptions; + +/// Options for the `FeatureAlphaDropout` module. +/// +/// Example: +/// ``` +/// FeatureAlphaDropout model(FeatureAlphaDropoutOptions(0.2).inplace(true)); +/// ``` +using FeatureAlphaDropoutOptions = DropoutOptions; + +namespace functional { + +/// Options for `torch::nn::functional::dropout`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::dropout(input, F::DropoutFuncOptions().p(0.5)); +/// ``` +struct TORCH_API DropoutFuncOptions { + /// The probability of an element to be zeroed. Default: 0.5 + TORCH_ARG(double, p) = 0.5; + + TORCH_ARG(bool, training) = true; + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace) = false; +}; + +/// Options for `torch::nn::functional::dropout2d`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::dropout2d(input, F::Dropout2dFuncOptions().p(0.5)); +/// ``` +using Dropout2dFuncOptions = DropoutFuncOptions; + +/// Options for `torch::nn::functional::dropout3d`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::dropout3d(input, F::Dropout3dFuncOptions().p(0.5)); +/// ``` +using Dropout3dFuncOptions = DropoutFuncOptions; + +/// Options for `torch::nn::functional::alpha_dropout`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::alpha_dropout(input, +/// F::AlphaDropoutFuncOptions().p(0.5).training(false)); +/// ``` +struct TORCH_API AlphaDropoutFuncOptions { + TORCH_ARG(double, p) = 0.5; + + TORCH_ARG(bool, training) = false; + + TORCH_ARG(bool, inplace) = false; +}; + +/// Options for `torch::nn::functional::feature_alpha_dropout`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::feature_alpha_dropout(input, +/// F::FeatureAlphaDropoutFuncOptions().p(0.5).training(false)); +/// ``` +struct TORCH_API FeatureAlphaDropoutFuncOptions { + TORCH_ARG(double, p) = 0.5; + + TORCH_ARG(bool, training) = false; + + TORCH_ARG(bool, inplace) = false; +}; + +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/embedding.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/embedding.h new file mode 100644 index 0000000000000000000000000000000000000000..be689f12b3bd979eb7090a8fa1793ebbe28698a2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/embedding.h @@ -0,0 +1,240 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn { + +/// Options for the `Embedding` module. +/// +/// Example: +/// ``` +/// Embedding model(EmbeddingOptions(10, +/// 2).padding_idx(3).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true)); +/// ``` +struct TORCH_API EmbeddingOptions { + EmbeddingOptions(int64_t num_embeddings, int64_t embedding_dim); + + /// The size of the dictionary of embeddings. + TORCH_ARG(int64_t, num_embeddings); + /// The size of each embedding vector. + TORCH_ARG(int64_t, embedding_dim); + /// If specified, the entries at `padding_idx` do not contribute to the + /// gradient; therefore, the embedding vector at `padding_idx` is not updated + /// during training, i.e. it remains as a fixed "pad". For a newly constructed + /// Embedding, the embedding vector at `padding_idx` will default to all + /// zeros, but can be updated to another value to be used as the padding + /// vector. + TORCH_ARG(std::optional, padding_idx) = std::nullopt; + /// If given, each embedding vector with norm larger than `max_norm` is + /// renormalized to have norm `max_norm`. + TORCH_ARG(std::optional, max_norm) = std::nullopt; + /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. + TORCH_ARG(double, norm_type) = 2.; + /// If given, this will scale gradients by the inverse of frequency of the + /// words in the mini-batch. Default ``false``. + TORCH_ARG(bool, scale_grad_by_freq) = false; + /// If ``true``, gradient w.r.t. `weight` matrix will be a sparse tensor. + TORCH_ARG(bool, sparse) = false; + /// The learnable weights of the module of shape (num_embeddings, + /// embedding_dim) + TORCH_ARG(torch::Tensor, _weight) = Tensor(); +}; + +// ============================================================================ + +/// Options for the `Embedding::from_pretrained` function. +struct TORCH_API EmbeddingFromPretrainedOptions { + /// If ``true``, the tensor does not get updated in the learning process. + /// Equivalent to ``embedding.weight.requires_grad_(false)``. Default: + /// ``true`` + TORCH_ARG(bool, freeze) = true; + /// If specified, the entries at `padding_idx` do not contribute to the + /// gradient; therefore, the embedding vector at `padding_idx` is not updated + /// during training, i.e. it remains as a fixed "pad". + TORCH_ARG(std::optional, padding_idx) = std::nullopt; + /// If given, each embedding vector with norm larger than `max_norm` is + /// renormalized to have norm `max_norm`. + TORCH_ARG(std::optional, max_norm) = std::nullopt; + /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. + TORCH_ARG(double, norm_type) = 2.; + /// If given, this will scale gradients by the inverse of frequency of the + /// words in the mini-batch. Default ``false``. + TORCH_ARG(bool, scale_grad_by_freq) = false; + /// If ``true``, gradient w.r.t. `weight` matrix will be a sparse tensor. + TORCH_ARG(bool, sparse) = false; +}; + +// ============================================================================ + +namespace functional { + +/// Options for `torch::nn::functional::embedding`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::embedding(input, weight, +/// F::EmbeddingFuncOptions().norm_type(2.5).scale_grad_by_freq(true).sparse(true)); +/// ``` +struct TORCH_API EmbeddingFuncOptions { + /// If specified, the entries at `padding_idx` do not contribute to the + /// gradient; therefore, the embedding vector at `padding_idx` is not updated + /// during training, i.e. it remains as a fixed "pad". + TORCH_ARG(std::optional, padding_idx) = std::nullopt; + /// If given, each embedding vector with norm larger than `max_norm` is + /// renormalized to have norm `max_norm`. + TORCH_ARG(std::optional, max_norm) = std::nullopt; + /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. + TORCH_ARG(double, norm_type) = 2.; + /// If given, this will scale gradients by the inverse of frequency of the + /// words in the mini-batch. Default ``false``. + TORCH_ARG(bool, scale_grad_by_freq) = false; + /// If ``true``, gradient w.r.t. `weight` matrix will be a sparse tensor. + TORCH_ARG(bool, sparse) = false; +}; + +} // namespace functional + +// ============================================================================ + +typedef std::variant + EmbeddingBagMode; + +/// Options for the `EmbeddingBag` module. +/// +/// Example: +/// ``` +/// EmbeddingBag model(EmbeddingBagOptions(10, +/// 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true).mode(torch::kSum)); +/// ``` +struct TORCH_API EmbeddingBagOptions { + EmbeddingBagOptions(int64_t num_embeddings, int64_t embedding_dim); + + /// The size of the dictionary of embeddings. + TORCH_ARG(int64_t, num_embeddings); + /// The size of each embedding vector. + TORCH_ARG(int64_t, embedding_dim); + /// If given, each embedding vector with norm larger than `max_norm` is + /// renormalized to have norm `max_norm`. + TORCH_ARG(std::optional, max_norm) = std::nullopt; + /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. + TORCH_ARG(double, norm_type) = 2.; + /// If given, this will scale gradients by the inverse of frequency of the + /// words in the mini-batch. Default ``false``. Note: this option is not + /// supported when ``mode="kMax"``. + TORCH_ARG(bool, scale_grad_by_freq) = false; + /// ``"kSum"``, ``"kMean"`` or ``"kMax"``. Specifies the way to reduce the + /// bag. ``"kSum"`` computes the weighted sum, taking `per_sample_weights` + /// into consideration. ``"kMean"`` computes the average of the values in the + /// bag, ``"kMax"`` computes the max value over each bag. + TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean; + /// If ``true``, gradient w.r.t. `weight` matrix will be a sparse tensor. + /// Note: this option is not supported when ``mode="kMax"``. + TORCH_ARG(bool, sparse) = false; + /// The learnable weights of the module of shape (num_embeddings, + /// embedding_dim) + TORCH_ARG(torch::Tensor, _weight) = Tensor(); + /// If ``true``, `offsets` has one additional element, where the last element + /// is equivalent to the size of `indices`. This matches the CSR format. + TORCH_ARG(bool, include_last_offset) = false; + /// If specified, the entries at `padding_idx` do not contribute to the + /// gradient; therefore, the embedding vector at padding_idx is not updated + /// during training, i.e. it remains as a fixed "pad". For a newly constructed + /// EmbeddingBag, the embedding vector at `padding_idx` will default to all + /// zeros, but can be updated to another value to be used as the padding + /// vector. Note that the embedding vector at `padding_idx` is excluded from + /// the reduction. + TORCH_ARG(std::optional, padding_idx) = std::nullopt; +}; + +// ============================================================================ + +/// Options for the `EmbeddingBag::from_pretrained` function. +struct TORCH_API EmbeddingBagFromPretrainedOptions { + /// If ``true``, the tensor does not get updated in the learning process. + /// Equivalent to ``embeddingbag.weight.requires_grad_(false)``. Default: + /// ``true`` + TORCH_ARG(bool, freeze) = true; + /// If given, each embedding vector with norm larger than `max_norm` is + /// renormalized to have norm `max_norm`. + TORCH_ARG(std::optional, max_norm) = std::nullopt; + /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. + TORCH_ARG(double, norm_type) = 2.; + /// If given, this will scale gradients by the inverse of frequency of the + /// words in the mini-batch. Default ``false``. Note: this option is not + /// supported when ``mode="kMax"``. + TORCH_ARG(bool, scale_grad_by_freq) = false; + /// ``"kSum"``, ``"kMean"`` or ``"kMax"``. Specifies the way to reduce the + /// bag. ``"kSum"`` computes the weighted sum, taking `per_sample_weights` + /// into consideration. ``"kMean"`` computes the average of the values in the + /// bag, ``"kMax"`` computes the max value over each bag. + TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean; + /// If ``true``, gradient w.r.t. `weight` matrix will be a sparse tensor. + /// Note: this option is not supported when ``mode="kMax"``. + TORCH_ARG(bool, sparse) = false; + /// If ``true``, `offsets` has one additional element, where the last element + /// is equivalent to the size of `indices`. This matches the CSR format. Note: + /// this option is currently only supported when ``mode="sum"``. + TORCH_ARG(bool, include_last_offset) = false; + /// If specified, the entries at `padding_idx` do not contribute to the + /// gradient; therefore, the embedding vector at padding_idx is not updated + /// during training, i.e. it remains as a fixed "pad". Note that the embedding + /// vector at `padding_idx` is excluded from the reduction. + TORCH_ARG(std::optional, padding_idx) = std::nullopt; +}; + +// ============================================================================ + +namespace functional { + +/// Options for `torch::nn::functional::embedding_bag`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::embedding_bag(input, weight, +/// F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets)); +/// ``` +struct TORCH_API EmbeddingBagFuncOptions { + /// Only used when `input` is 1D. `offsets` determines + /// the starting index position of each bag (sequence) in `input`. + TORCH_ARG(torch::Tensor, offsets) = Tensor(); + /// If given, each embedding vector with norm larger than `max_norm` is + /// renormalized to have norm `max_norm`. + TORCH_ARG(std::optional, max_norm) = std::nullopt; + /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. + TORCH_ARG(double, norm_type) = 2.; + /// If given, this will scale gradients by the inverse of frequency of the + /// words in the mini-batch. Default ``false``. Note: this option is not + /// supported when ``mode="kMax"``. + TORCH_ARG(bool, scale_grad_by_freq) = false; + /// ``"kSum"``, ``"kMean"`` or ``"kMax"``. Specifies the way to reduce the + /// bag. ``"kSum"`` computes the weighted sum, taking `per_sample_weights` + /// into consideration. ``"kMean"`` computes the average of the values in the + /// bag, ``"kMax"`` computes the max value over each bag. + TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean; + /// If ``true``, gradient w.r.t. `weight` matrix will be a sparse tensor. + /// Note: this option is not supported when ``mode="kMax"``. + TORCH_ARG(bool, sparse) = false; + /// a tensor of float / double weights, or None to indicate all weights should + /// be taken to be 1. If specified, `per_sample_weights` must have exactly the + /// same shape as input and is treated as having the same `offsets`, if those + /// are not None. + TORCH_ARG(torch::Tensor, per_sample_weights) = Tensor(); + /// If ``true``, `offsets` has one additional element, where the last element + /// is equivalent to the size of `indices`. This matches the CSR format. Note: + /// this option is currently only supported when ``mode="sum"``. + TORCH_ARG(bool, include_last_offset) = false; + /// If specified, the entries at `padding_idx` do not contribute to the + /// gradient; therefore, the embedding vector at padding_idx is not updated + /// during training, i.e. it remains as a fixed "pad". Note that the embedding + /// vector at `padding_idx` is excluded from the reduction. + TORCH_ARG(std::optional, padding_idx) = std::nullopt; +}; + +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/fold.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/fold.h new file mode 100644 index 0000000000000000000000000000000000000000..958105e159bb64ec635ec4eda6034b70da157c43 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/fold.h @@ -0,0 +1,95 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn { + +/// Options for the `Fold` module. +/// +/// Example: +/// ``` +/// Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, +/// 1}).stride(2)); +/// ``` +struct TORCH_API FoldOptions { + FoldOptions(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size) + : output_size_(output_size), kernel_size_(kernel_size) {} + + /// describes the spatial shape of the large containing tensor of the sliding + /// local blocks. It is useful to resolve the ambiguity when multiple input + /// shapes map to same number of sliding blocks, e.g., with stride > 0. + TORCH_ARG(ExpandingArray<2>, output_size); + + /// the size of the sliding blocks + TORCH_ARG(ExpandingArray<2>, kernel_size); + + /// controls the spacing between the kernel points; also known as the à trous + /// algorithm. + TORCH_ARG(ExpandingArray<2>, dilation) = 1; + + /// controls the amount of implicit zero-paddings on both sides for padding + /// number of points for each dimension before reshaping. + TORCH_ARG(ExpandingArray<2>, padding) = 0; + + /// controls the stride for the sliding blocks. + TORCH_ARG(ExpandingArray<2>, stride) = 1; +}; + +namespace functional { +/// Options for `torch::nn::functional::fold`. +/// +/// See the documentation for `torch::nn::FoldOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::fold(input, F::FoldFuncOptions({3, 2}, {2, 2})); +/// ``` +using FoldFuncOptions = FoldOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `Unfold` module. +/// +/// Example: +/// ``` +/// Unfold model(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2)); +/// ``` +struct TORCH_API UnfoldOptions { + UnfoldOptions(ExpandingArray<2> kernel_size) : kernel_size_(kernel_size) {} + + /// the size of the sliding blocks + TORCH_ARG(ExpandingArray<2>, kernel_size); + + /// controls the spacing between the kernel points; also known as the à trous + /// algorithm. + TORCH_ARG(ExpandingArray<2>, dilation) = 1; + + /// controls the amount of implicit zero-paddings on both sides for padding + /// number of points for each dimension before reshaping. + TORCH_ARG(ExpandingArray<2>, padding) = 0; + + /// controls the stride for the sliding blocks. + TORCH_ARG(ExpandingArray<2>, stride) = 1; +}; + +namespace functional { +/// Options for `torch::nn::functional::unfold`. +/// +/// See the documentation for `torch::nn::UnfoldOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::unfold(input, F::UnfoldFuncOptions({2, 2}).padding(1).stride(2)); +/// ``` +using UnfoldFuncOptions = UnfoldOptions; +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/instancenorm.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/instancenorm.h new file mode 100644 index 0000000000000000000000000000000000000000..2c90a060340b7d6ae8d46c5a717f9dbe04c9dfc7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/instancenorm.h @@ -0,0 +1,87 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn { + +/// Options for the `InstanceNorm` module. +struct TORCH_API InstanceNormOptions { + /* implicit */ InstanceNormOptions(int64_t num_features); + + /// The number of features of the input tensor. + TORCH_ARG(int64_t, num_features); + + /// The epsilon value added for numerical stability. + TORCH_ARG(double, eps) = 1e-5; + + /// A momentum multiplier for the mean and variance. + TORCH_ARG(double, momentum) = 0.1; + + /// Whether to learn a scale and bias that are applied in an affine + /// transformation on the input. + TORCH_ARG(bool, affine) = false; + + /// Whether to store and update batch statistics (mean and variance) in the + /// module. + TORCH_ARG(bool, track_running_stats) = false; +}; + +/// Options for the `InstanceNorm1d` module. +/// +/// Example: +/// ``` +/// InstanceNorm1d +/// model(InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +using InstanceNorm1dOptions = InstanceNormOptions; + +/// Options for the `InstanceNorm2d` module. +/// +/// Example: +/// ``` +/// InstanceNorm2d +/// model(InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +using InstanceNorm2dOptions = InstanceNormOptions; + +/// Options for the `InstanceNorm3d` module. +/// +/// Example: +/// ``` +/// InstanceNorm3d +/// model(InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); +/// ``` +using InstanceNorm3dOptions = InstanceNormOptions; + +namespace functional { + +/// Options for `torch::nn::functional::instance_norm`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::instance_norm(input, +/// F::InstanceNormFuncOptions().running_mean(mean).running_var(variance).weight(weight).bias(bias).momentum(0.1).eps(1e-5)); +/// ``` +struct TORCH_API InstanceNormFuncOptions { + TORCH_ARG(Tensor, running_mean) = Tensor(); + + TORCH_ARG(Tensor, running_var) = Tensor(); + + TORCH_ARG(Tensor, weight) = Tensor(); + + TORCH_ARG(Tensor, bias) = Tensor(); + + TORCH_ARG(bool, use_input_stats) = true; + + TORCH_ARG(double, momentum) = 0.1; + + TORCH_ARG(double, eps) = 1e-5; +}; + +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/linear.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/linear.h new file mode 100644 index 0000000000000000000000000000000000000000..6c045910b848c909d0d6727709cb4aa1a4932ad3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/linear.h @@ -0,0 +1,93 @@ +#pragma once + +#include +#include +#include + +namespace torch::nn { + +/// Options for the `Linear` module. +/// +/// Example: +/// ``` +/// Linear model(LinearOptions(5, 2).bias(false)); +/// ``` +struct TORCH_API LinearOptions { + LinearOptions(int64_t in_features, int64_t out_features); + /// size of each input sample + TORCH_ARG(int64_t, in_features); + + /// size of each output sample + TORCH_ARG(int64_t, out_features); + + /// If set to false, the layer will not learn an additive bias. Default: true + TORCH_ARG(bool, bias) = true; +}; + +// ============================================================================ + +/// Options for the `Flatten` module. +/// +/// Example: +/// ``` +/// Flatten model(FlattenOptions().start_dim(2).end_dim(4)); +/// ``` +struct TORCH_API FlattenOptions { + /// first dim to flatten + TORCH_ARG(int64_t, start_dim) = 1; + /// last dim to flatten + TORCH_ARG(int64_t, end_dim) = -1; +}; + +// ============================================================================ + +/// Options for the `Unflatten` module. +/// +/// Note: If input tensor is named, use dimname and namedshape arguments. +/// +/// Example: +/// ``` +/// Unflatten unnamed_model(UnflattenOptions(0, {2, 2})); +/// Unflatten named_model(UnflattenOptions("B", {{"B1", 2}, {"B2", 2}})); +/// ``` +struct TORCH_API UnflattenOptions { + typedef std::vector> namedshape_t; + + UnflattenOptions(int64_t dim, std::vector sizes); + UnflattenOptions(const char* dimname, namedshape_t namedshape); + UnflattenOptions(std::string dimname, namedshape_t namedshape); + + /// dim to unflatten + TORCH_ARG(int64_t, dim); + /// name of dim to unflatten, for use with named tensors + TORCH_ARG(std::string, dimname); + /// new shape of unflattened dim + TORCH_ARG(std::vector, sizes); + /// new shape of unflattened dim with names, for use with named tensors + TORCH_ARG(namedshape_t, namedshape); +}; + +// ============================================================================ + +/// Options for the `Bilinear` module. +/// +/// Example: +/// ``` +/// Bilinear model(BilinearOptions(3, 2, 4).bias(false)); +/// ``` +struct TORCH_API BilinearOptions { + BilinearOptions( + int64_t in1_features, + int64_t in2_features, + int64_t out_features); + /// The number of features in input 1 (columns of the input1 matrix). + TORCH_ARG(int64_t, in1_features); + /// The number of features in input 2 (columns of the input2 matrix). + TORCH_ARG(int64_t, in2_features); + /// The number of output features to produce (columns of the output matrix). + TORCH_ARG(int64_t, out_features); + /// Whether to learn and add a bias after the bilinear transformation. + TORCH_ARG(bool, bias) = true; +}; + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/loss.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/loss.h new file mode 100644 index 0000000000000000000000000000000000000000..88d954c5e18b5f6a5b772bed2130239548ca34cf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/loss.h @@ -0,0 +1,800 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn { + +/// Options for the `L1Loss` module. +/// +/// Example: +/// ``` +/// L1Loss model(L1LossOptions(torch::kNone)); +/// ``` +struct TORCH_API L1LossOptions { + typedef std::variant + reduction_t; + + TORCH_OPTIONS_CTOR_VARIANT_ARG3(L1LossOptions, reduction, kNone, kMean, kSum) + + /// Specifies the reduction to apply to the output. + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::l1_loss`. +/// +/// See the documentation for `torch::nn::L1LossOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::l1_loss(input, target, F::L1LossFuncOptions(torch::kNone)); +/// ``` +using L1LossFuncOptions = L1LossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `KLDivLoss` module. +/// +/// Example: +/// ``` +/// KLDivLoss +/// model(KLDivLossOptions().reduction(torch::kNone).log_target(false)); +/// ``` +struct TORCH_API KLDivLossOptions { + typedef std::variant< + enumtype::kNone, + enumtype::kBatchMean, + enumtype::kSum, + enumtype::kMean> + reduction_t; + + TORCH_OPTIONS_CTOR_VARIANT_ARG4( + KLDivLossOptions, + reduction, + kNone, + kBatchMean, + kSum, + kMean) + + /// Specifies the reduction to apply to the output. + /// ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``. Default: ``'mean'`` + TORCH_ARG(reduction_t, reduction) = torch::kMean; + + /// Specifies whether `target` is accepted in the log space. Default: False + TORCH_ARG(bool, log_target) = false; +}; + +namespace functional { +/// Options for `torch::nn::functional::kl_div`. +/// +/// See the documentation for `torch::nn::KLDivLossOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::kl_div(input, target, +/// F::KLDivFuncOptions().reduction(torch::kNone).log_target(false)); +/// ``` +using KLDivFuncOptions = KLDivLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `MSELoss` module. +/// +/// Example: +/// ``` +/// MSELoss model(MSELossOptions(torch::kNone)); +/// ``` +struct TORCH_API MSELossOptions { + typedef std::variant + reduction_t; + + TORCH_OPTIONS_CTOR_VARIANT_ARG3(MSELossOptions, reduction, kNone, kMean, kSum) + + /// Specifies the reduction to apply to the output. + /// ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'`` + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::mse_loss`. +/// +/// See the documentation for `torch::nn::MSELossOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::mse_loss(input, target, F::MSELossFuncOptions(torch::kNone)); +/// ``` +using MSELossFuncOptions = MSELossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `BCELoss` module. +/// +/// Example: +/// ``` +/// BCELoss model(BCELossOptions().reduction(torch::kNone).weight(weight)); +/// ``` +struct TORCH_API BCELossOptions { + typedef std::variant + reduction_t; + + /// A manual rescaling weight given to the loss of each batch element. + TORCH_ARG(Tensor, weight) = {}; + /// Specifies the reduction to apply to the output. + /// ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'`` + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::binary_cross_entropy`. +/// +/// See the documentation for `torch::nn::BCELossOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::binary_cross_entropy(input, target, +/// F::BinaryCrossEntropyFuncOptions().weight(weight)); +/// ``` +using BinaryCrossEntropyFuncOptions = BCELossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `HingeEmbeddingLoss` module. +/// +/// Example: +/// ``` +/// HingeEmbeddingLoss +/// model(HingeEmbeddingLossOptions().margin(4).reduction(torch::kNone)); +/// ``` +struct TORCH_API HingeEmbeddingLossOptions { + typedef std::variant + reduction_t; + + /// Specifies the threshold for which the distance of a negative sample must + /// reach in order to incur zero loss. Default: 1 + TORCH_ARG(double, margin) = 1.0; + /// Specifies the reduction to apply to the output. Default: Mean + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::hinge_embedding_loss`. +/// +/// See the documentation for `torch::nn::HingeEmbeddingLossOptions` class to +/// learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::hinge_embedding_loss(input, target, +/// F::HingeEmbeddingLossFuncOptions().margin(2)); +/// ``` +using HingeEmbeddingLossFuncOptions = HingeEmbeddingLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `MultiMarginLoss` module. +/// +/// Example: +/// ``` +/// MultiMarginLoss model(MultiMarginLossOptions().margin(2).weight(weight)); +/// ``` +struct TORCH_API MultiMarginLossOptions { + typedef std::variant + reduction_t; + + /// Has a default value of :math:`1`. :math:`1` and :math:`2` + /// are the only supported values. + TORCH_ARG(int64_t, p) = 1; + /// Has a default value of :math:`1`. + TORCH_ARG(double, margin) = 1.0; + /// A manual rescaling weight given to each + /// class. If given, it has to be a Tensor of size `C`. Otherwise, it is + /// treated as if having all ones. + TORCH_ARG(Tensor, weight) = Tensor(); + /// Specifies the reduction to apply to the output: + /// ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be + /// applied, + /// ``'mean'``: the sum of the output will be divided by the number of + /// elements in the output, ``'sum'``: the output will be summed. Default: + /// ``'mean'`` + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::multi_margin_loss`. +/// +/// See the documentation for `torch::nn::MultiMarginLossOptions` class to learn +/// what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::multi_margin_loss(input, target, +/// F::MultiMarginLossFuncOptions().margin(2).weight(weight)); +/// ``` +using MultiMarginLossFuncOptions = MultiMarginLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `CosineEmbeddingLoss` module. +/// +/// Example: +/// ``` +/// CosineEmbeddingLoss model(CosineEmbeddingLossOptions().margin(0.5)); +/// ``` +struct TORCH_API CosineEmbeddingLossOptions { + typedef std::variant + reduction_t; + + /// Specifies the threshold for which the distance of a negative sample must + /// reach in order to incur zero loss. Should be a number from -1 to 1, 0 + /// to 0.5 is suggested. Default: 0.0 + TORCH_ARG(double, margin) = 0.0; + /// Specifies the reduction to apply to the output. Default: Mean + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::cosine_embedding_loss`. +/// +/// See the documentation for `torch::nn::CosineEmbeddingLossOptions` class to +/// learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::cosine_embedding_loss(input1, input2, target, +/// F::CosineEmbeddingLossFuncOptions().margin(0.5)); +/// ``` +using CosineEmbeddingLossFuncOptions = CosineEmbeddingLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `MultiLabelMarginLoss` module. +/// +/// Example: +/// ``` +/// MultiLabelMarginLoss model(MultiLabelMarginLossOptions(torch::kNone)); +/// ``` +struct TORCH_API MultiLabelMarginLossOptions { + typedef std::variant + reduction_t; + + TORCH_OPTIONS_CTOR_VARIANT_ARG3( + MultiLabelMarginLossOptions, + reduction, + kNone, + kMean, + kSum) + + /// Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. + /// 'none': no reduction will be applied, 'mean': the sum of the output will + /// be divided by the number of elements in the output, 'sum': the output will + /// be summed. Default: 'mean' + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::multilabel_margin_loss`. +/// +/// See the documentation for `torch::nn::MultiLabelMarginLossOptions` class to +/// learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::multilabel_margin_loss(input, target, +/// F::MultilabelMarginLossFuncOptions(torch::kNone)); +/// ``` +using MultilabelMarginLossFuncOptions = MultiLabelMarginLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `SoftMarginLoss` module. +/// +/// Example: +/// ``` +/// SoftMarginLoss model(SoftMarginLossOptions(torch::kNone)); +/// ``` +struct TORCH_API SoftMarginLossOptions { + typedef std::variant + reduction_t; + + TORCH_OPTIONS_CTOR_VARIANT_ARG3( + SoftMarginLossOptions, + reduction, + kNone, + kMean, + kSum) + + /// Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. + /// 'none': no reduction will be applied, 'mean': the sum of the output will + /// be divided by the number of elements in the output, 'sum': the output will + /// be summed. Default: 'mean' + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::soft_margin_loss`. +/// +/// See the documentation for `torch::nn::SoftMarginLossOptions` class to learn +/// what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::soft_margin_loss(input, target, +/// F::SoftMarginLossFuncOptions(torch::kNone)); +/// ``` +using SoftMarginLossFuncOptions = SoftMarginLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `MultiLabelSoftMarginLoss` module. +/// +/// Example: +/// ``` +/// MultiLabelSoftMarginLoss +/// model(MultiLabelSoftMarginLossOptions().reduction(torch::kNone).weight(weight)); +/// ``` +struct TORCH_API MultiLabelSoftMarginLossOptions { + typedef std::variant + reduction_t; + + /// A manual rescaling weight given to each + /// class. If given, it has to be a Tensor of size `C`. Otherwise, it is + /// treated as if having all ones. + TORCH_ARG(Tensor, weight) = Tensor(); + + /// Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. + /// 'none': no reduction will be applied, 'mean': the sum of the output will + /// be divided by the number of elements in the output, 'sum': the output will + /// be summed. Default: 'mean' + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::multilabel_soft_margin_loss`. +/// +/// See the documentation for `torch::nn::MultiLabelSoftMarginLossOptions` class +/// to learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::multilabel_soft_margin_loss(input, target, +/// F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone).weight(weight)); +/// ``` +using MultilabelSoftMarginLossFuncOptions = MultiLabelSoftMarginLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `TripletMarginLoss` module. +/// +/// Example: +/// ``` +/// TripletMarginLoss +/// model(TripletMarginLossOptions().margin(3).p(2).eps(1e-06).swap(false)); +/// ``` +struct TORCH_API TripletMarginLossOptions { + typedef std::variant + reduction_t; + + /// Specifies the threshold for which the distance of a negative sample must + /// reach in order to incur zero loss. Default: 1 + TORCH_ARG(double, margin) = 1.0; + /// Specifies the norm degree for pairwise distance. Default: 2 + TORCH_ARG(double, p) = 2.0; + TORCH_ARG(double, eps) = 1e-6; + /// The distance swap is described in detail in the paper Learning shallow + /// convolutional feature descriptors with triplet losses by V. Balntas, + /// E. Riba et al. Default: False + TORCH_ARG(bool, swap) = false; + /// Specifies the reduction to apply to the output. Default: Mean + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::triplet_margin_loss`. +/// +/// See the documentation for `torch::nn::TripletMarginLossOptions` class to +/// learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::triplet_margin_loss(anchor, positive, negative, +/// F::TripletMarginLossFuncOptions().margin(1.0)); +/// ``` +using TripletMarginLossFuncOptions = TripletMarginLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `TripletMarginWithDistanceLoss` module. +/// +/// Example: +/// ``` +/// TripletMarginWithDistanceLoss +/// model(TripletMarginWithDistanceLossOptions().margin(3).swap(false)); +/// ``` +struct TORCH_API TripletMarginWithDistanceLossOptions { + typedef std::variant + reduction_t; + typedef std::function + distance_function_t; + + /// Specifies a nonnegative, real-valued function that quantifies the + /// closeness of two tensors. If not specified, `F::pairwise_distance` will + /// be used. Default: nullopt + TORCH_ARG(std::optional, distance_function) = + std::nullopt; + /// Specifies a nonnegative margin representing the minimum difference + /// between the positive and negative distances required for the loss to be 0. + /// Larger margins penalize cases where the negative examples are not distance + /// enough from the anchors, relative to the positives. Default: 1 + TORCH_ARG(double, margin) = 1.0; + /// Whether to use the distance swap described in the paper Learning shallow + /// convolutional feature descriptors with triplet losses by V. Balntas, + /// E. Riba et al. If True, and if the positive example is closer to the + /// negative example than the anchor is, swaps the positive example and the + /// anchor in the loss computation. Default: False + TORCH_ARG(bool, swap) = false; + /// Specifies the reduction to apply to the output. Default: Mean + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::triplet_margin_with_distance_loss`. +/// +/// See the documentation for `torch::nn::TripletMarginWithDistanceLossOptions` +/// class to learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::triplet_margin_with_distance_loss(anchor, positive, negative, +/// F::TripletMarginWithDistanceLossFuncOptions().margin(1.0)); +/// ``` +using TripletMarginWithDistanceLossFuncOptions = + TripletMarginWithDistanceLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `CTCLoss` module. +/// +/// Example: +/// ``` +/// CTCLoss +/// model(CTCLossOptions().blank(42).zero_infinity(false).reduction(torch::kSum)); +/// ``` +struct TORCH_API CTCLossOptions { + typedef std::variant + reduction_t; + + /// blank label. Default `0`. + TORCH_ARG(int64_t, blank) = 0; + /// Specifies the reduction to apply to the output. Default: Mean + TORCH_ARG(reduction_t, reduction) = torch::kMean; + /// Whether to zero infinite losses and the associated gradients. + /// Default: `false`. Infinite losses mainly occur when the inputs are + /// too short to be aligned to the targets. + TORCH_ARG(bool, zero_infinity) = false; +}; + +namespace functional { +/// Options for `torch::nn::functional::ctc_loss`. +/// +/// See the documentation for `torch::nn::CTCLossOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::ctc_loss(log_probs, targets, input_lengths, target_lengths, +/// F::CTCLossFuncOptions().reduction(torch::kNone)); +/// ``` +using CTCLossFuncOptions = CTCLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `SmoothL1Loss` module. +/// +/// Example: +/// ``` +/// SmoothL1Loss model(SmoothL1LossOptions().reduction(torch::kNone).beta(0.5)); +/// ``` +struct TORCH_API SmoothL1LossOptions { + typedef std::variant + reduction_t; + + TORCH_OPTIONS_CTOR_VARIANT_ARG3( + SmoothL1LossOptions, + reduction, + kNone, + kMean, + kSum) + + /// Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. + /// 'none': no reduction will be applied, 'mean': the sum of the output will + /// be divided by the number of elements in the output, 'sum': the output will + /// be summed. Default: 'mean' + TORCH_ARG(reduction_t, reduction) = torch::kMean; + /// Specifies the threshold at which to change between L1 and L2 loss. + /// If beta is not specified, a value of 1.0 will be used. + /// Default: nullopt + TORCH_ARG(std::optional, beta) = std::nullopt; +}; + +namespace functional { +/// Options for `torch::nn::functional::smooth_l1_loss`. +/// +/// See the documentation for `torch::nn::SmoothL1LossOptions` class to learn +/// what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::smooth_l1_loss(input, target, F::SmoothL1LossFuncOptions(torch::kNone)); +/// ``` +using SmoothL1LossFuncOptions = SmoothL1LossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `HuberLoss` module. +/// +/// Example: +/// ``` +/// HuberLoss model(HuberLossOptions().reduction(torch::kNone).delta(0.5)); +/// ``` +struct TORCH_API HuberLossOptions { + typedef std::variant + reduction_t; + + TORCH_OPTIONS_CTOR_VARIANT_ARG3( + HuberLossOptions, + reduction, + kNone, + kMean, + kSum) + + /// Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. + /// 'none': no reduction will be applied, 'mean': the sum of the output will + /// be divided by the number of elements in the output, 'sum': the output will + /// be summed. Default: 'mean' + TORCH_ARG(reduction_t, reduction) = torch::kMean; + /// Specifies the threshold at which to change between L1 and L2 loss. + /// Default: 1.0 + TORCH_ARG(double, delta) = 1.0; +}; + +namespace functional { +/// Options for `torch::nn::functional::huber_loss`. +/// +/// See the documentation for `torch::nn::HuberLossOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::huber_loss(input, target, F::HuberLossFuncOptions(torch::kNone)); +/// ``` +using HuberLossFuncOptions = HuberLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `PoissonNLLLoss` module. +/// +/// Example: +/// ``` +/// PoissonNLLLoss +/// model(PoissonNLLLossOptions().log_input(false).full(true).eps(0.42).reduction(torch::kSum)); +/// ``` +struct TORCH_API PoissonNLLLossOptions { + typedef std::variant + reduction_t; + + /// if true the loss is computed as `exp(input) - target * input`, + /// if false the loss is `input - target * log(input + eps)`. + TORCH_ARG(bool, log_input) = true; + /// whether to compute full loss, i.e. to add the Stirling approximation term + /// target * log(target) - target + 0.5 * log(2 * pi * target). + TORCH_ARG(bool, full) = false; + /// Small value to avoid evaluation of `log(0)` when `log_input = false`. + /// Default: 1e-8 + TORCH_ARG(double, eps) = 1e-8; + /// Specifies the reduction to apply to the output. Default: Mean + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::poisson_nll_loss`. +/// +/// See the documentation for `torch::nn::PoissonNLLLossOptions` class to learn +/// what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::poisson_nll_loss(input, target, +/// F::PoissonNLLLossFuncOptions().reduction(torch::kNone)); +/// ``` +using PoissonNLLLossFuncOptions = PoissonNLLLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `MarginRankingLoss` module. +/// +/// Example: +/// ``` +/// MarginRankingLoss +/// model(MarginRankingLossOptions().margin(0.5).reduction(torch::kSum)); +/// ``` +struct TORCH_API MarginRankingLossOptions { + typedef std::variant + reduction_t; + + /// Has a default value of `0`. + TORCH_ARG(double, margin) = 0; + /// Specifies the reduction to apply to the output. Default: Mean + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::margin_ranking_loss`. +/// +/// See the documentation for `torch::nn::MarginRankingLossOptions` class to +/// learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::margin_ranking_loss(input1, input2, target, +/// F::MarginRankingLossFuncOptions().margin(0.5).reduction(torch::kSum)); +/// ``` +using MarginRankingLossFuncOptions = MarginRankingLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `NLLLoss` module. +/// +/// Example: +/// ``` +/// NLLLoss model(NLLLossOptions().ignore_index(-100).reduction(torch::kMean)); +/// ``` +struct TORCH_API NLLLossOptions { + typedef std::variant + reduction_t; + + /// A manual rescaling weight given to each + /// class. If given, it has to be a Tensor of size `C`. Otherwise, it is + /// treated as if having all ones. + TORCH_ARG(Tensor, weight) = {}; + /// Specifies a target value that is ignored + /// and does not contribute to the input gradient. + TORCH_ARG(int64_t, ignore_index) = -100; + /// Specifies the reduction to apply to the output. Default: Mean + TORCH_ARG(reduction_t, reduction) = torch::kMean; +}; + +namespace functional { +/// Options for `torch::nn::functional::nll_loss`. +/// +/// See the documentation for `torch::nn::NLLLossOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::nll_loss(input, target, +/// F::NLLLossFuncOptions().ignore_index(-100).reduction(torch::kMean)); +/// ``` +using NLLLossFuncOptions = NLLLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `CrossEntropyLoss` module. +/// +/// Example: +/// ``` +/// CrossEntropyLoss +/// model(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean)); +/// ``` +struct TORCH_API CrossEntropyLossOptions { + typedef std::variant + reduction_t; + + /// A manual rescaling weight given to each class. If given, has to be a + /// Tensor of size C + TORCH_ARG(Tensor, weight) = {}; + /// Specifies a target value that is ignored + /// and does not contribute to the input gradient. + TORCH_ARG(int64_t, ignore_index) = -100; + /// Specifies the reduction to apply to the output. Default: Mean + TORCH_ARG(reduction_t, reduction) = torch::kMean; + /// Specifies the amount of smoothing when computing the loss. Default: 0.0 + TORCH_ARG(double, label_smoothing) = 0.0; +}; + +namespace functional { +/// Options for `torch::nn::functional::cross_entropy`. +/// +/// See the documentation for `torch::nn::CrossEntropyLossOptions` class to +/// learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::cross_entropy(input, target, +/// F::CrossEntropyFuncOptions().ignore_index(-100).reduction(torch::kMean)); +/// ``` +using CrossEntropyFuncOptions = CrossEntropyLossOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `BCEWithLogitsLoss` module. +/// +/// Example: +/// ``` +/// BCEWithLogitsLoss +/// model(BCEWithLogitsLossOptions().reduction(torch::kNone).weight(weight)); +/// ``` +struct TORCH_API BCEWithLogitsLossOptions { + typedef std::variant + reduction_t; + /// A manual rescaling weight given to the loss of each batch element. + /// If given, has to be a Tensor of size `nbatch`. + TORCH_ARG(Tensor, weight) = {}; + /// Specifies the reduction to apply to the output. Default: Mean + TORCH_ARG(reduction_t, reduction) = torch::kMean; + /// A weight of positive examples. + /// Must be a vector with length equal to the number of classes. + TORCH_ARG(Tensor, pos_weight) = {}; +}; + +namespace functional { +/// Options for `torch::nn::functional::binary_cross_entropy_with_logits`. +/// +/// See the documentation for `torch::nn::BCEWithLogitsLossOptions` class to +/// learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::binary_cross_entropy_with_logits(input, target, +/// F::BinaryCrossEntropyWithLogitsFuncOptions().pos_weight(pos_weight).reduction(torch::kSum)); +/// ``` +using BinaryCrossEntropyWithLogitsFuncOptions = BCEWithLogitsLossOptions; +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/normalization.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/normalization.h new file mode 100644 index 0000000000000000000000000000000000000000..6097a2923af2fc37fdf31cb2d3d247718b606ff4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/normalization.h @@ -0,0 +1,190 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn { + +/// Options for the `LayerNorm` module. +/// +/// Example: +/// ``` +/// LayerNorm model(LayerNormOptions({2, +/// 2}).elementwise_affine(false).eps(2e-5)); +/// ``` +struct TORCH_API LayerNormOptions { + /* implicit */ LayerNormOptions(std::vector normalized_shape); + /// input shape from an expected input. + TORCH_ARG(std::vector, normalized_shape); + /// a value added to the denominator for numerical stability. ``Default: + /// 1e-5``. + TORCH_ARG(double, eps) = 1e-5; + /// a boolean value that when set to ``true``, this module + /// has learnable per-element affine parameters initialized to ones (for + /// weights) and zeros (for biases). ``Default: true``. + TORCH_ARG(bool, elementwise_affine) = true; +}; + +// ============================================================================ + +namespace functional { + +/// Options for `torch::nn::functional::layer_norm`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::layer_norm(input, F::LayerNormFuncOptions({2, 2}).eps(2e-5)); +/// ``` +struct TORCH_API LayerNormFuncOptions { + /* implicit */ LayerNormFuncOptions(std::vector normalized_shape); + /// input shape from an expected input. + TORCH_ARG(std::vector, normalized_shape); + + TORCH_ARG(Tensor, weight) = {}; + + TORCH_ARG(Tensor, bias) = {}; + + /// a value added to the denominator for numerical stability. ``Default: + /// 1e-5``. + TORCH_ARG(double, eps) = 1e-5; +}; + +} // namespace functional + +// ============================================================================ + +/// Options for the `LocalResponseNorm` module. +/// +/// Example: +/// ``` +/// LocalResponseNorm +/// model(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.)); +/// ``` +struct TORCH_API LocalResponseNormOptions { + /* implicit */ LocalResponseNormOptions(int64_t size) : size_(size) {} + /// amount of neighbouring channels used for normalization + TORCH_ARG(int64_t, size); + + /// multiplicative factor. Default: 1e-4 + TORCH_ARG(double, alpha) = 1e-4; + + /// exponent. Default: 0.75 + TORCH_ARG(double, beta) = 0.75; + + /// additive factor. Default: 1 + TORCH_ARG(double, k) = 1.; +}; + +namespace functional { +/// Options for `torch::nn::functional::local_response_norm`. +/// +/// See the documentation for `torch::nn::LocalResponseNormOptions` class to +/// learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::local_response_norm(x, F::LocalResponseNormFuncOptions(2)); +/// ``` +using LocalResponseNormFuncOptions = LocalResponseNormOptions; +} // namespace functional + +// ============================================================================ + +/// Options for the `CrossMapLRN2d` module. +/// +/// Example: +/// ``` +/// CrossMapLRN2d model(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10)); +/// ``` +struct TORCH_API CrossMapLRN2dOptions { + CrossMapLRN2dOptions(int64_t size); + + TORCH_ARG(int64_t, size); + + TORCH_ARG(double, alpha) = 1e-4; + + TORCH_ARG(double, beta) = 0.75; + + TORCH_ARG(int64_t, k) = 1; +}; + +// ============================================================================ + +namespace functional { + +/// Options for `torch::nn::functional::normalize`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1)); +/// ``` +struct TORCH_API NormalizeFuncOptions { + /// The exponent value in the norm formulation. Default: 2.0 + TORCH_ARG(double, p) = 2.0; + /// The dimension to reduce. Default: 1 + TORCH_ARG(int64_t, dim) = 1; + /// Small value to avoid division by zero. Default: 1e-12 + TORCH_ARG(double, eps) = 1e-12; + /// the output tensor. If `out` is used, this + /// operation won't be differentiable. + TORCH_ARG(std::optional, out) = std::nullopt; +}; + +} // namespace functional + +// ============================================================================ + +/// Options for the `GroupNorm` module. +/// +/// Example: +/// ``` +/// GroupNorm model(GroupNormOptions(2, 2).eps(2e-5).affine(false)); +/// ``` +struct TORCH_API GroupNormOptions { + /* implicit */ GroupNormOptions(int64_t num_groups, int64_t num_channels); + + /// number of groups to separate the channels into + TORCH_ARG(int64_t, num_groups); + /// number of channels expected in input + TORCH_ARG(int64_t, num_channels); + /// a value added to the denominator for numerical stability. Default: 1e-5 + TORCH_ARG(double, eps) = 1e-5; + /// a boolean value that when set to ``true``, this module + /// has learnable per-channel affine parameters initialized to ones (for + /// weights) and zeros (for biases). Default: ``true``. + TORCH_ARG(bool, affine) = true; +}; + +// ============================================================================ + +namespace functional { + +/// Options for `torch::nn::functional::group_norm`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::group_norm(input, F::GroupNormFuncOptions(2).eps(2e-5)); +/// ``` +struct TORCH_API GroupNormFuncOptions { + /* implicit */ GroupNormFuncOptions(int64_t num_groups); + + /// number of groups to separate the channels into + TORCH_ARG(int64_t, num_groups); + + TORCH_ARG(Tensor, weight) = {}; + + TORCH_ARG(Tensor, bias) = {}; + + /// a value added to the denominator for numerical stability. Default: 1e-5 + TORCH_ARG(double, eps) = 1e-5; +}; + +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/padding.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/padding.h new file mode 100644 index 0000000000000000000000000000000000000000..efe71cff29005ed7a60d38c262138553fe7ef105 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/padding.h @@ -0,0 +1,217 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch::nn { + +/// Options for a `D`-dimensional ReflectionPad module. +template +struct TORCH_API ReflectionPadOptions { + ReflectionPadOptions(ExpandingArray padding) : padding_(padding) {} + + /// The size of the padding. + /// If it is `int`, uses the same padding in all boundaries. + /// If it is a 2-`tuple` (for ReflectionPad1d), uses (padding_left, + /// padding_right). If it is a 4-`tuple` (for ReflectionPad2d), uses + /// (padding_left, padding_right, padding_top, padding_bottom). If it is a + /// 6-`tuple` (for ReflectionPad3d), uses (padding_left, padding_right, + /// padding_top, padding_bottom, padding_front, padding_back). + + TORCH_ARG(ExpandingArray, padding); +}; + +/// `ReflectionPadOptions` specialized for the `ReflectionPad1d` module. +/// +/// Example: +/// ``` +/// ReflectionPad1d model(ReflectionPad1dOptions({3, 1})); +/// ``` +using ReflectionPad1dOptions = ReflectionPadOptions<1>; + +/// `ReflectionPadOptions` specialized for the `ReflectionPad2d` module. +/// +/// Example: +/// ``` +/// ReflectionPad2d model(ReflectionPad2dOptions({1, 1, 2, 0})); +/// ``` +using ReflectionPad2dOptions = ReflectionPadOptions<2>; + +/// `ReflectionPadOptions` specialized for the `ReflectionPad3d` module. +/// +/// Example: +/// ``` +/// ReflectionPad3d model(ReflectionPad3dOptions({1, 1, 2, 0, 1, 1})); +/// ``` +using ReflectionPad3dOptions = ReflectionPadOptions<3>; + +// ============================================================================ + +/// Options for a `D`-dimensional ReplicationPad module. +template +struct TORCH_API ReplicationPadOptions { + ReplicationPadOptions(ExpandingArray padding) : padding_(padding) {} + + /// The size of the padding. + /// - If it is `int`, uses the same padding in all boundaries. + /// - If it is a 2-`tuple` (for ReplicationPad1d), uses (padding_left, + /// padding_right). + /// - If it is a 4-`tuple` (for ReplicationPad2d), uses (padding_left, + /// padding_right, padding_top, padding_bottom). + /// - If it is a 6-`tuple` (for ReplicationPad3d), uses + /// (padding_left, padding_right, padding_top, padding_bottom, + /// padding_front, padding_back). + TORCH_ARG(ExpandingArray, padding); +}; + +/// `ReplicationPadOptions` specialized for the `ReplicationPad1d` module. +/// +/// Example: +/// ``` +/// ReplicationPad1d model(ReplicationPad1dOptions({3, 1})); +/// ``` +using ReplicationPad1dOptions = ReplicationPadOptions<1>; + +/// `ReplicationPadOptions` specialized for the `ReplicationPad2d` module. +/// +/// Example: +/// ``` +/// ReplicationPad2d model(ReplicationPad2dOptions({1, 1, 2, 0})); +/// ``` +using ReplicationPad2dOptions = ReplicationPadOptions<2>; + +/// `ReplicationPadOptions` specialized for the `ReplicationPad3d` module. +/// +/// Example: +/// ``` +/// ReplicationPad3d model(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2})); +/// ``` +using ReplicationPad3dOptions = ReplicationPadOptions<3>; + +// ============================================================================ + +template +struct TORCH_API ZeroPadOptions { + ZeroPadOptions(ExpandingArray padding) : padding_(padding) {} + + /// The size of the padding. + /// - If it is `int`, uses the same padding in all boundaries. + /// - If it is a 2-`tuple` (for ZeroPad1d), uses (padding_left, + /// padding_right). + /// - If it is a 4-`tuple` (for ZeroPad2d), uses (padding_left, padding_right, + /// padding_top, padding_bottom). + /// - If it is a 6-`tuple` (for ZeroPad3d), uses + /// (padding_left, padding_right, padding_top, padding_bottom, + /// padding_front, padding_back). + TORCH_ARG(ExpandingArray, padding); +}; + +/// `ZeroPadOptions` specialized for the `ZeroPad1d` module. +/// +/// Example: +/// ``` +/// ConstantPad1d model(ConstantPad1dOptions({3, 1}); +/// ``` +using ZeroPad1dOptions = ZeroPadOptions<1>; + +/// `ZeroPadOptions` specialized for the `ZeroPad2d` module. +/// +/// Example: +/// ``` +/// ConstantPad2d model(ConstantPad2dOptions({1, 1, 2, 0}); +/// ``` +using ZeroPad2dOptions = ZeroPadOptions<2>; + +/// `ZeroPadOptions` specialized for the `ZeroPad3d` module. +/// +/// Example: +/// ``` +/// ConstantPad3d model(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}); +/// ``` +using ZeroPad3dOptions = ZeroPadOptions<3>; + +// ============================================================================ + +/// Options for a `D`-dimensional ConstantPad module. +template +struct TORCH_API ConstantPadOptions { + ConstantPadOptions(ExpandingArray padding, double value) + : padding_(padding), value_(value) {} + + /// The size of the padding. + /// - If it is `int`, uses the same padding in all boundaries. + /// - If it is a 2-`tuple` (for ConstantPad1d), uses (padding_left, + /// padding_right). + /// - If it is a 4-`tuple` (for ConstantPad2d), uses (padding_left, + /// padding_right, padding_top, padding_bottom). + /// - If it is a 6-`tuple` (for ConstantPad3d), uses + /// (padding_left, padding_right, padding_top, padding_bottom, + /// padding_front, padding_back). + TORCH_ARG(ExpandingArray, padding); + + /// Fill value for constant padding. + TORCH_ARG(double, value); +}; + +/// `ConstantPadOptions` specialized for the `ConstantPad1d` module. +/// +/// Example: +/// ``` +/// ConstantPad1d model(ConstantPad1dOptions({3, 1}, 3.5)); +/// ``` +using ConstantPad1dOptions = ConstantPadOptions<1>; + +/// `ConstantPadOptions` specialized for the `ConstantPad2d` module. +/// +/// Example: +/// ``` +/// ConstantPad2d model(ConstantPad2dOptions({3, 0, 2, 1}, 3.5)); +/// ``` +using ConstantPad2dOptions = ConstantPadOptions<2>; + +/// `ConstantPadOptions` specialized for the `ConstantPad3d` module. +/// +/// Example: +/// ``` +/// ConstantPad3d model(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5)); +/// ``` +using ConstantPad3dOptions = ConstantPadOptions<3>; + +// ============================================================================ + +namespace functional { + +/// Options for `torch::nn::functional::pad`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::pad(input, F::PadFuncOptions({1, 2, 2, 1, 1, +/// 2}).mode(torch::kReplicate)); +/// ``` +struct TORCH_API PadFuncOptions { + typedef std::variant< + enumtype::kConstant, + enumtype::kReflect, + enumtype::kReplicate, + enumtype::kCircular> + mode_t; + + PadFuncOptions(std::vector pad); + + /// m-elements tuple, where m/2 <= input dimensions and m is even. + TORCH_ARG(std::vector, pad); + + /// "constant", "reflect", "replicate" or "circular". Default: "constant" + TORCH_ARG(mode_t, mode) = torch::kConstant; + + /// fill value for "constant" padding. Default: 0 + TORCH_ARG(double, value) = 0; +}; + +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/pixelshuffle.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/pixelshuffle.h new file mode 100644 index 0000000000000000000000000000000000000000..8de36fb614861cb58138d83f190798d0e7e17e6c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/pixelshuffle.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include +#include + +namespace torch::nn { + +/// Options for the `PixelShuffle` module. +/// +/// Example: +/// ``` +/// PixelShuffle model(PixelShuffleOptions(5)); +/// ``` +struct TORCH_API PixelShuffleOptions { + PixelShuffleOptions(int64_t upscale_factor) + : upscale_factor_(upscale_factor) {} + + /// Factor to increase spatial resolution by + TORCH_ARG(int64_t, upscale_factor); +}; + +/// Options for the `PixelUnshuffle` module. +/// +/// Example: +/// ``` +/// PixelUnshuffle model(PixelUnshuffleOptions(5)); +/// ``` +struct TORCH_API PixelUnshuffleOptions { + /* implicit */ PixelUnshuffleOptions(int64_t downscale_factor) + : downscale_factor_(downscale_factor) {} + + /// Factor to decrease spatial resolution by + TORCH_ARG(int64_t, downscale_factor); +}; + +namespace functional { +/// Options for `torch::nn::functional::pixel_shuffle`. +/// +/// See the documentation for `torch::nn::PixelShuffleOptions` class to learn +/// what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::pixel_shuffle(x, F::PixelShuffleFuncOptions(2)); +/// ``` +using PixelShuffleFuncOptions = PixelShuffleOptions; + +/// Options for `torch::nn::functional::pixel_unshuffle`. +/// +/// See the documentation for `torch::nn::PixelUnshuffleOptions` class to learn +/// what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::pixel_unshuffle(x, F::PixelUnshuffleFuncOptions(2)); +/// ``` +using PixelUnshuffleFuncOptions = PixelUnshuffleOptions; +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/pooling.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..3934f326c8a5da08241b93e51534b7b06acb27fa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/pooling.h @@ -0,0 +1,594 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn { + +/// Options for a `D`-dimensional avgpool module. +template +struct AvgPoolOptions { + AvgPoolOptions(ExpandingArray kernel_size) + : kernel_size_(kernel_size), stride_(kernel_size) {} + + /// the size of the window to take an average over + TORCH_ARG(ExpandingArray, kernel_size); + + /// the stride of the window. Default value is `kernel_size` + TORCH_ARG(ExpandingArray, stride); + + /// implicit zero padding to be added on both sides + TORCH_ARG(ExpandingArray, padding) = 0; + + /// when True, will use `ceil` instead of `floor` to compute the output shape + TORCH_ARG(bool, ceil_mode) = false; + + /// when True, will include the zero-padding in the averaging calculation + TORCH_ARG(bool, count_include_pad) = true; + + /// if specified, it will be used as divisor, otherwise size of the pooling + /// region will be used. + + TORCH_ARG(std::optional, divisor_override) = std::nullopt; +}; + +/// `AvgPoolOptions` specialized for the `AvgPool1d` module. +/// +/// Example: +/// ``` +/// AvgPool1d model(AvgPool1dOptions(3).stride(2)); +/// ``` +using AvgPool1dOptions = AvgPoolOptions<1>; + +/// `AvgPoolOptions` specialized for the `AvgPool2d` module. +/// +/// Example: +/// ``` +/// AvgPool2d model(AvgPool2dOptions({3, 2}).stride({2, 2})); +/// ``` +using AvgPool2dOptions = AvgPoolOptions<2>; + +/// `AvgPoolOptions` specialized for the `AvgPool3d` module. +/// +/// Example: +/// ``` +/// AvgPool3d model(AvgPool3dOptions(5).stride(2)); +/// ``` +using AvgPool3dOptions = AvgPoolOptions<3>; + +namespace functional { +/// Options for `torch::nn::functional::avg_pool1d`. +/// +/// See the documentation for `torch::nn::AvgPool1dOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::avg_pool1d(x, F::AvgPool1dFuncOptions(3).stride(2)); +/// ``` +using AvgPool1dFuncOptions = AvgPool1dOptions; +} // namespace functional + +namespace functional { +/// Options for `torch::nn::functional::avg_pool2d`. +/// +/// See the documentation for `torch::nn::AvgPool2dOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::avg_pool2d(x, F::AvgPool2dFuncOptions(3).stride(2)); +/// ``` +using AvgPool2dFuncOptions = AvgPool2dOptions; +} // namespace functional + +namespace functional { +/// Options for `torch::nn::functional::avg_pool3d`. +/// +/// See the documentation for `torch::nn::AvgPool3dOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::avg_pool3d(x, F::AvgPool3dFuncOptions(3).stride(2)); +/// ``` +using AvgPool3dFuncOptions = AvgPool3dOptions; +} // namespace functional + +// ============================================================================ + +/// Options for a `D`-dimensional maxpool module. +template +struct MaxPoolOptions { + MaxPoolOptions(ExpandingArray kernel_size) + : kernel_size_(kernel_size), stride_(kernel_size) {} + + /// the size of the window to take a max over + TORCH_ARG(ExpandingArray, kernel_size); + + /// the stride of the window. Default value is `kernel_size + TORCH_ARG(ExpandingArray, stride); + + /// implicit zero padding to be added on both sides + TORCH_ARG(ExpandingArray, padding) = 0; + + /// a parameter that controls the stride of elements in the window + TORCH_ARG(ExpandingArray, dilation) = 1; + + /// when True, will use `ceil` instead of `floor` to compute the output shape + TORCH_ARG(bool, ceil_mode) = false; +}; + +/// `MaxPoolOptions` specialized for the `MaxPool1d` module. +/// +/// Example: +/// ``` +/// MaxPool1d model(MaxPool1dOptions(3).stride(2)); +/// ``` +using MaxPool1dOptions = MaxPoolOptions<1>; + +/// `MaxPoolOptions` specialized for the `MaxPool2d` module. +/// +/// Example: +/// ``` +/// MaxPool2d model(MaxPool2dOptions({3, 2}).stride({2, 2})); +/// ``` +using MaxPool2dOptions = MaxPoolOptions<2>; + +/// `MaxPoolOptions` specialized for the `MaxPool3d` module. +/// +/// Example: +/// ``` +/// MaxPool3d model(MaxPool3dOptions(3).stride(2)); +/// ``` +using MaxPool3dOptions = MaxPoolOptions<3>; + +namespace functional { +/// Options for `torch::nn::functional::max_pool1d` and +/// `torch::nn::functional::max_pool1d_with_indices`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_pool1d(x, F::MaxPool1dFuncOptions(3).stride(2)); +/// ``` +using MaxPool1dFuncOptions = MaxPool1dOptions; +} // namespace functional + +namespace functional { +/// Options for `torch::nn::functional::max_pool2d` and +/// `torch::nn::functional::max_pool2d_with_indices`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_pool2d(x, F::MaxPool2dFuncOptions(3).stride(2)); +/// ``` +using MaxPool2dFuncOptions = MaxPool2dOptions; +} // namespace functional + +namespace functional { +/// Options for `torch::nn::functional::max_pool3d` and +/// `torch::nn::functional::max_pool3d_with_indices`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_pool3d(x, F::MaxPool3dFuncOptions(3).stride(2)); +/// ``` +using MaxPool3dFuncOptions = MaxPool3dOptions; +} // namespace functional + +// ============================================================================ + +/// Options for a `D`-dimensional adaptive maxpool module. +template +struct AdaptiveMaxPoolOptions { + AdaptiveMaxPoolOptions(output_size_t output_size) + : output_size_(output_size) {} + + /// the target output size + TORCH_ARG(output_size_t, output_size); +}; + +/// `AdaptiveMaxPoolOptions` specialized for the `AdaptiveMaxPool1d` module. +/// +/// Example: +/// ``` +/// AdaptiveMaxPool1d model(AdaptiveMaxPool1dOptions(3)); +/// ``` +using AdaptiveMaxPool1dOptions = AdaptiveMaxPoolOptions>; + +/// `AdaptiveMaxPoolOptions` specialized for the `AdaptiveMaxPool2d` module. +/// +/// Example: +/// ``` +/// AdaptiveMaxPool2d model(AdaptiveMaxPool2dOptions({3, 2})); +/// ``` +using AdaptiveMaxPool2dOptions = + AdaptiveMaxPoolOptions>; + +/// `AdaptiveMaxPoolOptions` specialized for the `AdaptiveMaxPool3d` module. +/// +/// Example: +/// ``` +/// AdaptiveMaxPool3d model(AdaptiveMaxPool3dOptions(3)); +/// ``` +using AdaptiveMaxPool3dOptions = + AdaptiveMaxPoolOptions>; + +namespace functional { +/// Options for `torch::nn::functional::adaptive_max_pool1d` and +/// `torch::nn::functional::adaptive_max_pool1d_with_indices` +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_max_pool1d(x, F::AdaptiveMaxPool1dFuncOptions(3)); +/// ``` +using AdaptiveMaxPool1dFuncOptions = AdaptiveMaxPool1dOptions; +} // namespace functional + +namespace functional { +/// Options for `torch::nn::functional::adaptive_max_pool2d` and +/// `torch::nn::functional::adaptive_max_pool2d_with_indices` +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_max_pool2d(x, F::AdaptiveMaxPool2dFuncOptions(3)); +/// ``` +using AdaptiveMaxPool2dFuncOptions = AdaptiveMaxPool2dOptions; +} // namespace functional + +namespace functional { +/// Options for `torch::nn::functional::adaptive_max_pool3d` and +/// `torch::nn::functional::adaptive_max_pool3d_with_indices` +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_max_pool3d(x, F::AdaptiveMaxPool3dFuncOptions(3)); +/// ``` +using AdaptiveMaxPool3dFuncOptions = AdaptiveMaxPool3dOptions; +} // namespace functional + +// ============================================================================ + +/// Options for a `D`-dimensional adaptive avgpool module. +template +struct AdaptiveAvgPoolOptions { + AdaptiveAvgPoolOptions(output_size_t output_size) + : output_size_(output_size) {} + + /// the target output size + TORCH_ARG(output_size_t, output_size); +}; + +/// `AdaptiveAvgPoolOptions` specialized for the `AdaptiveAvgPool1d` module. +/// +/// Example: +/// ``` +/// AdaptiveAvgPool1d model(AdaptiveAvgPool1dOptions(5)); +/// ``` +using AdaptiveAvgPool1dOptions = AdaptiveAvgPoolOptions>; + +/// `AdaptiveAvgPoolOptions` specialized for the `AdaptiveAvgPool2d` module. +/// +/// Example: +/// ``` +/// AdaptiveAvgPool2d model(AdaptiveAvgPool2dOptions({3, 2})); +/// ``` +using AdaptiveAvgPool2dOptions = + AdaptiveAvgPoolOptions>; + +/// `AdaptiveAvgPoolOptions` specialized for the `AdaptiveAvgPool3d` module. +/// +/// Example: +/// ``` +/// AdaptiveAvgPool3d model(AdaptiveAvgPool3dOptions(3)); +/// ``` +using AdaptiveAvgPool3dOptions = + AdaptiveAvgPoolOptions>; + +namespace functional { +/// Options for `torch::nn::functional::adaptive_avg_pool1d`. +/// +/// See the documentation for `torch::nn::AdaptiveAvgPool1dOptions` class to +/// learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_avg_pool1d(x, F::AdaptiveAvgPool1dFuncOptions(3)); +/// ``` +using AdaptiveAvgPool1dFuncOptions = AdaptiveAvgPool1dOptions; +} // namespace functional + +namespace functional { +/// Options for `torch::nn::functional::adaptive_avg_pool2d`. +/// +/// See the documentation for `torch::nn::AdaptiveAvgPool2dOptions` class to +/// learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_avg_pool2d(x, F::AdaptiveAvgPool2dFuncOptions(3)); +/// ``` +using AdaptiveAvgPool2dFuncOptions = AdaptiveAvgPool2dOptions; +} // namespace functional + +namespace functional { +/// Options for `torch::nn::functional::adaptive_avg_pool3d`. +/// +/// See the documentation for `torch::nn::AdaptiveAvgPool3dOptions` class to +/// learn what arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::adaptive_avg_pool3d(x, F::AdaptiveAvgPool3dFuncOptions(3)); +/// ``` +using AdaptiveAvgPool3dFuncOptions = AdaptiveAvgPool3dOptions; +} // namespace functional + +// ============================================================================ + +/// Options for a `D`-dimensional maxunpool module. +template +struct MaxUnpoolOptions { + MaxUnpoolOptions(ExpandingArray kernel_size) + : kernel_size_(kernel_size), stride_(kernel_size) {} + + /// the size of the window to take a max over + TORCH_ARG(ExpandingArray, kernel_size); + + /// the stride of the window. Default value is `kernel_size + TORCH_ARG(ExpandingArray, stride); + + /// implicit zero padding to be added on both sides + TORCH_ARG(ExpandingArray, padding) = 0; +}; + +/// `MaxUnpoolOptions` specialized for the `MaxUnpool1d` module. +/// +/// Example: +/// ``` +/// MaxUnpool1d model(MaxUnpool1dOptions(3).stride(2).padding(1)); +/// ``` +using MaxUnpool1dOptions = MaxUnpoolOptions<1>; + +/// `MaxUnpoolOptions` specialized for the `MaxUnpool2d` module. +/// +/// Example: +/// ``` +/// MaxUnpool2d model(MaxUnpool2dOptions(3).stride(2).padding(1)); +/// ``` +using MaxUnpool2dOptions = MaxUnpoolOptions<2>; + +/// `MaxUnpoolOptions` specialized for the `MaxUnpool3d` module. +/// +/// Example: +/// ``` +/// MaxUnpool3d model(MaxUnpool3dOptions(3).stride(2).padding(1)); +/// ``` +using MaxUnpool3dOptions = MaxUnpoolOptions<3>; + +// ============================================================================ + +namespace functional { + +/// Options for a `D`-dimensional maxunpool functional. +template +struct MaxUnpoolFuncOptions { + MaxUnpoolFuncOptions(ExpandingArray kernel_size) + : kernel_size_(kernel_size), stride_(kernel_size) {} + + /// the size of the window to take a max over + TORCH_ARG(ExpandingArray, kernel_size); + + /// the stride of the window. Default value is `kernel_size + TORCH_ARG(ExpandingArray, stride); + + /// implicit zero padding to be added on both sides + TORCH_ARG(ExpandingArray, padding) = 0; + + /// the targeted output size + TORCH_ARG(std::optional>, output_size) = std::nullopt; +}; + +/// `MaxUnpoolFuncOptions` specialized for +/// `torch::nn::functional::max_unpool1d`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_unpool1d(x, indices, +/// F::MaxUnpool1dFuncOptions(3).stride(2).padding(1)); +/// ``` +using MaxUnpool1dFuncOptions = MaxUnpoolFuncOptions<1>; + +/// `MaxUnpoolFuncOptions` specialized for +/// `torch::nn::functional::max_unpool2d`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_unpool2d(x, indices, +/// F::MaxUnpool2dFuncOptions(3).stride(2).padding(1)); +/// ``` +using MaxUnpool2dFuncOptions = MaxUnpoolFuncOptions<2>; + +/// `MaxUnpoolFuncOptions` specialized for +/// `torch::nn::functional::max_unpool3d`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::max_unpool3d(x, indices, F::MaxUnpool3dFuncOptions(3)); +/// ``` +using MaxUnpool3dFuncOptions = MaxUnpoolFuncOptions<3>; + +} // namespace functional + +// ============================================================================ + +/// Options for a `D`-dimensional fractional maxpool module. +template +struct FractionalMaxPoolOptions { + FractionalMaxPoolOptions(ExpandingArray kernel_size) + : kernel_size_(kernel_size) {} + + /// the size of the window to take a max over + TORCH_ARG(ExpandingArray, kernel_size); + + /// the target output size of the image + TORCH_ARG(std::optional>, output_size) = std::nullopt; + + /// If one wants to have an output size as a ratio of the input size, this + /// option can be given. This has to be a number or tuple in the range (0, 1) + using ExpandingArrayDouble = torch::ExpandingArray; + TORCH_ARG(std::optional, output_ratio) = std::nullopt; + + TORCH_ARG(torch::Tensor, _random_samples) = Tensor(); +}; + +/// `FractionalMaxPoolOptions` specialized for the `FractionalMaxPool2d` module. +/// +/// Example: +/// ``` +/// FractionalMaxPool2d model(FractionalMaxPool2dOptions(5).output_size(1)); +/// ``` +using FractionalMaxPool2dOptions = FractionalMaxPoolOptions<2>; + +/// `FractionalMaxPoolOptions` specialized for the `FractionalMaxPool3d` module. +/// +/// Example: +/// ``` +/// FractionalMaxPool3d model(FractionalMaxPool3dOptions(5).output_size(1)); +/// ``` +using FractionalMaxPool3dOptions = FractionalMaxPoolOptions<3>; + +namespace functional { +/// Options for `torch::nn::functional::fractional_max_pool2d` and +/// `torch::nn::functional::fractional_max_pool2d_with_indices` +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::fractional_max_pool2d(x, +/// F::FractionalMaxPool2dFuncOptions(3).output_size(2)); +/// ``` +using FractionalMaxPool2dFuncOptions = FractionalMaxPool2dOptions; +} // namespace functional + +namespace functional { +/// Options for `torch::nn::functional::fractional_max_pool3d` and +/// `torch::nn::functional::fractional_max_pool3d_with_indices` +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::fractional_max_pool3d(x, +/// F::FractionalMaxPool3dFuncOptions(3).output_size(2)); +/// ``` +using FractionalMaxPool3dFuncOptions = FractionalMaxPool3dOptions; +} // namespace functional + +// ============================================================================ + +/// Options for a `D`-dimensional lppool module. +template +struct LPPoolOptions { + LPPoolOptions(double norm_type, ExpandingArray kernel_size) + : norm_type_(norm_type), + kernel_size_(kernel_size), + stride_(kernel_size) {} + + TORCH_ARG(double, norm_type); + + // the size of the window to take an average over + TORCH_ARG(ExpandingArray, kernel_size); + + // the stride of the window. Default value is `kernel_size` + TORCH_ARG(ExpandingArray, stride); + + // when True, will use `ceil` instead of `floor` to compute the output shape + TORCH_ARG(bool, ceil_mode) = false; +}; + +/// `LPPoolOptions` specialized for the `LPPool1d` module. +/// +/// Example: +/// ``` +/// LPPool1d model(LPPool1dOptions(1, 2).stride(5).ceil_mode(true)); +/// ``` +using LPPool1dOptions = LPPoolOptions<1>; + +/// `LPPoolOptions` specialized for the `LPPool2d` module. +/// +/// Example: +/// ``` +/// LPPool2d model(LPPool2dOptions(1, std::vector({3, 4})).stride({5, +/// 6}).ceil_mode(true)); +/// ``` +using LPPool2dOptions = LPPoolOptions<2>; + +/// `LPPoolOptions` specialized for the `LPPool3d` module. +/// +/// Example: +/// ``` +/// LPPool3d model(LPPool3dOptions(1, std::vector({3, 4, 5})).stride( +/// {5, 6, 7}).ceil_mode(true)); +/// ``` +using LPPool3dOptions = LPPoolOptions<3>; + +namespace functional { +/// Options for `torch::nn::functional::lp_pool1d`. +/// +/// See the documentation for `torch::nn::LPPool1dOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::lp_pool1d(x, F::LPPool1dFuncOptions(2, 3).stride(2)); +/// ``` +using LPPool1dFuncOptions = LPPool1dOptions; +} // namespace functional + +namespace functional { +/// Options for `torch::nn::functional::lp_pool2d`. +/// +/// See the documentation for `torch::nn::LPPool2dOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::lp_pool2d(x, F::LPPool2dFuncOptions(2, {2, 3}).stride(2)); +/// ``` +using LPPool2dFuncOptions = LPPool2dOptions; +} // namespace functional + +namespace functional { +/// Options for `torch::nn::functional::lp_pool3d`. +/// +/// See the documentation for `torch::nn::LPPool3dOptions` class to learn what +/// arguments are supported. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::lp_pool3d(x, F::LPPool3dFuncOptions(2, {2, 3, 4}).stride(2)); +/// ``` +using LPPool3dFuncOptions = LPPool3dOptions; +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/rnn.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/rnn.h new file mode 100644 index 0000000000000000000000000000000000000000..44d9b5ab6b61714211251cb323dec620d7076fc5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/rnn.h @@ -0,0 +1,234 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn { + +namespace detail { + +/// Common options for RNN, LSTM and GRU modules. +struct TORCH_API RNNOptionsBase { + typedef std::variant< + enumtype::kLSTM, + enumtype::kGRU, + enumtype::kRNN_TANH, + enumtype::kRNN_RELU> + rnn_options_base_mode_t; + + RNNOptionsBase( + rnn_options_base_mode_t mode, + int64_t input_size, + int64_t hidden_size); + + TORCH_ARG(rnn_options_base_mode_t, mode); + /// The number of features of a single sample in the input sequence `x`. + TORCH_ARG(int64_t, input_size); + /// The number of features in the hidden state `h`. + TORCH_ARG(int64_t, hidden_size); + /// The number of recurrent layers (cells) to use. + TORCH_ARG(int64_t, num_layers) = 1; + /// Whether a bias term should be added to all linear operations. + TORCH_ARG(bool, bias) = true; + /// If true, the input sequence should be provided as `(batch, sequence, + /// features)`. If false (default), the expected layout is `(sequence, batch, + /// features)`. + TORCH_ARG(bool, batch_first) = false; + /// If non-zero, adds dropout with the given probability to the output of each + /// RNN layer, except the final layer. + TORCH_ARG(double, dropout) = 0.0; + /// Whether to make the RNN bidirectional. + TORCH_ARG(bool, bidirectional) = false; + /// Cell projection dimension. If 0, projections are not added. Can only be + /// used for LSTMs. + TORCH_ARG(int64_t, proj_size) = 0; +}; + +} // namespace detail + +/// Options for the `RNN` module. +/// +/// Example: +/// ``` +/// RNN model(RNNOptions(128, +/// 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh)); +/// ``` +struct TORCH_API RNNOptions { + typedef std::variant nonlinearity_t; + + RNNOptions(int64_t input_size, int64_t hidden_size); + + /// The number of expected features in the input `x` + TORCH_ARG(int64_t, input_size); + /// The number of features in the hidden state `h` + TORCH_ARG(int64_t, hidden_size); + /// Number of recurrent layers. E.g., setting ``num_layers=2`` + /// would mean stacking two RNNs together to form a `stacked RNN`, + /// with the second RNN taking in outputs of the first RNN and + /// computing the final results. Default: 1 + TORCH_ARG(int64_t, num_layers) = 1; + /// The non-linearity to use. Can be either ``torch::kTanh`` or + /// ``torch::kReLU``. Default: ``torch::kTanh`` + TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh; + /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. + /// Default: ``true`` + TORCH_ARG(bool, bias) = true; + /// If ``true``, then the input and output tensors are provided + /// as `(batch, seq, feature)`. Default: ``false`` + TORCH_ARG(bool, batch_first) = false; + /// If non-zero, introduces a `Dropout` layer on the outputs of each + /// RNN layer except the last layer, with dropout probability equal to + /// `dropout`. Default: 0 + TORCH_ARG(double, dropout) = 0.0; + /// If ``true``, becomes a bidirectional RNN. Default: ``false`` + TORCH_ARG(bool, bidirectional) = false; +}; + +/// Options for the `LSTM` module. +/// +/// Example: +/// ``` +/// LSTM model(LSTMOptions(2, +/// 4).num_layers(3).batch_first(false).bidirectional(true)); +/// ``` +struct TORCH_API LSTMOptions { + LSTMOptions(int64_t input_size, int64_t hidden_size); + + /// The number of expected features in the input `x` + TORCH_ARG(int64_t, input_size); + /// The number of features in the hidden state `h` + TORCH_ARG(int64_t, hidden_size); + /// Number of recurrent layers. E.g., setting ``num_layers=2`` + /// would mean stacking two LSTMs together to form a `stacked LSTM`, + /// with the second LSTM taking in outputs of the first LSTM and + /// computing the final results. Default: 1 + TORCH_ARG(int64_t, num_layers) = 1; + /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. + /// Default: ``true`` + TORCH_ARG(bool, bias) = true; + /// If ``true``, then the input and output tensors are provided + /// as (batch, seq, feature). Default: ``false`` + TORCH_ARG(bool, batch_first) = false; + /// If non-zero, introduces a `Dropout` layer on the outputs of each + /// LSTM layer except the last layer, with dropout probability equal to + /// `dropout`. Default: 0 + TORCH_ARG(double, dropout) = 0.0; + /// If ``true``, becomes a bidirectional LSTM. Default: ``false`` + TORCH_ARG(bool, bidirectional) = false; + /// Cell projection dimension. If 0, projections are not added + TORCH_ARG(int64_t, proj_size) = 0; +}; + +/// Options for the `GRU` module. +/// +/// Example: +/// ``` +/// GRU model(GRUOptions(2, +/// 4).num_layers(3).batch_first(false).bidirectional(true)); +/// ``` +struct TORCH_API GRUOptions { + GRUOptions(int64_t input_size, int64_t hidden_size); + + /// The number of expected features in the input `x` + TORCH_ARG(int64_t, input_size); + /// The number of features in the hidden state `h` + TORCH_ARG(int64_t, hidden_size); + /// Number of recurrent layers. E.g., setting ``num_layers=2`` + /// would mean stacking two GRUs together to form a `stacked GRU`, + /// with the second GRU taking in outputs of the first GRU and + /// computing the final results. Default: 1 + TORCH_ARG(int64_t, num_layers) = 1; + /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. + /// Default: ``true`` + TORCH_ARG(bool, bias) = true; + /// If ``true``, then the input and output tensors are provided + /// as (batch, seq, feature). Default: ``false`` + TORCH_ARG(bool, batch_first) = false; + /// If non-zero, introduces a `Dropout` layer on the outputs of each + /// GRU layer except the last layer, with dropout probability equal to + /// `dropout`. Default: 0 + TORCH_ARG(double, dropout) = 0.0; + /// If ``true``, becomes a bidirectional GRU. Default: ``false`` + TORCH_ARG(bool, bidirectional) = false; +}; + +namespace detail { + +/// Common options for RNNCell, LSTMCell and GRUCell modules +struct TORCH_API RNNCellOptionsBase { + RNNCellOptionsBase( + int64_t input_size, + int64_t hidden_size, + bool bias, + int64_t num_chunks); + TORCH_ARG(int64_t, input_size); + TORCH_ARG(int64_t, hidden_size); + TORCH_ARG(bool, bias); + TORCH_ARG(int64_t, num_chunks); +}; + +} // namespace detail + +/// Options for the `RNNCell` module. +/// +/// Example: +/// ``` +/// RNNCell model(RNNCellOptions(20, +/// 10).bias(false).nonlinearity(torch::kReLU)); +/// ``` +struct TORCH_API RNNCellOptions { + typedef std::variant nonlinearity_t; + + RNNCellOptions(int64_t input_size, int64_t hidden_size); + + /// The number of expected features in the input `x` + TORCH_ARG(int64_t, input_size); + /// The number of features in the hidden state `h` + TORCH_ARG(int64_t, hidden_size); + /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. + /// Default: ``true`` + TORCH_ARG(bool, bias) = true; + /// The non-linearity to use. Can be either ``torch::kTanh`` or + /// ``torch::kReLU``. Default: ``torch::kTanh`` + TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh; +}; + +/// Options for the `LSTMCell` module. +/// +/// Example: +/// ``` +/// LSTMCell model(LSTMCellOptions(20, 10).bias(false)); +/// ``` +struct TORCH_API LSTMCellOptions { + LSTMCellOptions(int64_t input_size, int64_t hidden_size); + + /// The number of expected features in the input `x` + TORCH_ARG(int64_t, input_size); + /// The number of features in the hidden state `h` + TORCH_ARG(int64_t, hidden_size); + /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. + /// Default: ``true`` + TORCH_ARG(bool, bias) = true; +}; + +/// Options for the `GRUCell` module. +/// +/// Example: +/// ``` +/// GRUCell model(GRUCellOptions(20, 10).bias(false)); +/// ``` +struct TORCH_API GRUCellOptions { + GRUCellOptions(int64_t input_size, int64_t hidden_size); + + /// The number of expected features in the input `x` + TORCH_ARG(int64_t, input_size); + /// The number of features in the hidden state `h` + TORCH_ARG(int64_t, hidden_size); + /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. + /// Default: ``true`` + TORCH_ARG(bool, bias) = true; +}; + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/transformer.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/transformer.h new file mode 100644 index 0000000000000000000000000000000000000000..a5ecba9d22637c24fd3b375d9ac4397156dc47c2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/transformer.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace torch::nn { + +/// Options for the `Transformer` module +/// +/// Example: +/// ``` +/// TransformerOptions options; +/// TransformerOptions options(16, 4); +/// auto options = TransformerOptions().d_model(4).nhead(2).dropout(0.0); +/// ``` +struct TORCH_API TransformerOptions { + // The following constructors are commonly used + // Please don't add more unless it is proved as a common usage + TransformerOptions() = default; + TransformerOptions(int64_t d_model, int64_t nhead); + TransformerOptions( + int64_t d_model, + int64_t nhead, + int64_t num_encoder_layers, + int64_t num_decoder_layers); + + /// the number of expected features in the encoder/decoder inputs + /// (default=512) + TORCH_ARG(int64_t, d_model) = 512; + + /// the number of heads in the multiheadattention models (default=8) + TORCH_ARG(int64_t, nhead) = 8; + + /// the number of sub-encoder-layers in the encoder (default=6) + TORCH_ARG(int64_t, num_encoder_layers) = 6; + + /// the number of sub-decoder-layers in the decoder (default=6) + TORCH_ARG(int64_t, num_decoder_layers) = 6; + + /// the dimension of the feedforward network model (default=2048) + TORCH_ARG(int64_t, dim_feedforward) = 2048; + + /// the dropout value (default=0.1) + TORCH_ARG(double, dropout) = 0.1; + + /// the activation function of encoder/decoder intermediate layer + /// (default=``torch::kReLU``) + TORCH_ARG(activation_t, activation) = torch::kReLU; + + /// custom encoder (default=None) + TORCH_ARG(AnyModule, custom_encoder); + + /// custom decoder (default=None) + TORCH_ARG(AnyModule, custom_decoder); +}; + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/transformercoder.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/transformercoder.h new file mode 100644 index 0000000000000000000000000000000000000000..343cce605b60f311d603d72f62e2189250989022 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/transformercoder.h @@ -0,0 +1,74 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace torch::nn { + +/// Options for the `TransformerEncoder` +/// +/// Example: +/// ``` +/// TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512, +/// 8).dropout(0.1)); auto options = TransformerEncoderOptions(encoderLayer, +/// 6).norm(LayerNorm(LayerNormOptions({2}))); +/// ``` +struct TORCH_API TransformerEncoderOptions { + // This constructor will keep a shallow copy of encoder_layer, so it keeps all + // the data in encoder_layer. + TransformerEncoderOptions( + TransformerEncoderLayer encoder_layer, + int64_t num_layers); + // This constructor will create a new TransformerEncoderLayer obj based on + // passed in encoder_layer_options. + TransformerEncoderOptions( + const TransformerEncoderLayerOptions& encoder_layer_options, + int64_t num_layers); + + /// transformer Encoder Layer + TORCH_ARG(TransformerEncoderLayer, encoder_layer) = nullptr; + + /// number of encoder layers + TORCH_ARG(int64_t, num_layers); + + /// normalization module + TORCH_ARG(AnyModule, norm); +}; + +/// Options for the `TransformerDecoder` module. +/// +/// Example: +/// ``` +/// TransformerDecoderLayer decoder_layer(TransformerDecoderLayerOptions(512, +/// 8).dropout(0.1)); auto options = TransformerDecoderOptions(decoder_layer, +/// 6)norm(LayerNorm(LayerNormOptions({2}))); TransformerDecoder +/// transformer_decoder(options); +/// ``` +struct TORCH_API TransformerDecoderOptions { + // This constructor will keep the a ref of passed in decoder_layer, + // so it keeps all the data in decoder_layer. + TransformerDecoderOptions( + TransformerDecoderLayer decoder_layer, + int64_t num_layers); + // This constructor will create a new TransformerDecoderLayer obj, + // based on passed in decoder_layer_options. + TransformerDecoderOptions( + const TransformerDecoderLayerOptions& decoder_layer_options, + int64_t num_layers); + + /// decoder layer to be cloned + TORCH_ARG(TransformerDecoderLayer, decoder_layer) = nullptr; + + /// number of decoder layers + TORCH_ARG(int64_t, num_layers); + + /// normalization module + TORCH_ARG(AnyModule, norm); +}; + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/transformerlayer.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/transformerlayer.h new file mode 100644 index 0000000000000000000000000000000000000000..d20f60567b9e2e8e26890aaa5e4998bd9c086afd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/transformerlayer.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn { + +using activation_t = std::variant< + enumtype::kReLU, + enumtype::kGELU, + std::function>; + +/// Options for the `TransformerEncoderLayer` +/// +/// Example: +/// ``` +/// auto options = TransformerEncoderLayer(512, 8).dropout(0.2); +/// ``` +struct TORCH_API TransformerEncoderLayerOptions { + /* implicit */ TransformerEncoderLayerOptions(int64_t d_model, int64_t nhead); + + /// the number of expected features in the input + TORCH_ARG(int64_t, d_model); + + /// the number of heads in the multiheadattention models + TORCH_ARG(int64_t, nhead); + + /// the dimension of the feedforward network model, default is 2048 + TORCH_ARG(int64_t, dim_feedforward) = 2048; + + /// the dropout value, default is 0.1 + TORCH_ARG(double, dropout) = 0.1; + + /// the activation function of intermediate layer, can be ``torch::kReLU``, + /// ``torch::GELU``, or a unary callable. Default: ``torch::kReLU`` + TORCH_ARG(activation_t, activation) = torch::kReLU; +}; + +// ============================================================================ + +/// Options for the `TransformerDecoderLayer` module. +/// +/// Example: +/// ``` +/// TransformerDecoderLayer model(TransformerDecoderLayerOptions(512, +/// 8).dropout(0.2)); +/// ``` +struct TORCH_API TransformerDecoderLayerOptions { + TransformerDecoderLayerOptions(int64_t d_model, int64_t nhead); + + /// number of expected features in the input + TORCH_ARG(int64_t, d_model); + + /// number of heads in the multiheadattention models + TORCH_ARG(int64_t, nhead); + + /// dimension of the feedforward network model. Default: 2048 + TORCH_ARG(int64_t, dim_feedforward) = 2048; + + /// dropout value. Default: 1 + TORCH_ARG(double, dropout) = 0.1; + + /// activation function of intermediate layer, can be ``torch::kGELU``, + /// ``torch::kReLU``, or a unary callable. Default: ``torch::kReLU`` + TORCH_ARG(activation_t, activation) = torch::kReLU; +}; + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/upsampling.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/upsampling.h new file mode 100644 index 0000000000000000000000000000000000000000..a0d6bb57182c40e1bd251dd9c364fee91496b915 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/upsampling.h @@ -0,0 +1,108 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace torch::nn { + +/// Options for the `Upsample` module. +/// +/// Example: +/// ``` +/// Upsample +/// model(UpsampleOptions().scale_factor(std::vector({3})).mode(torch::kLinear).align_corners(false)); +/// ``` +struct TORCH_API UpsampleOptions { + /// output spatial sizes. + TORCH_ARG(std::optional>, size) = std::nullopt; + + /// multiplier for spatial size. + TORCH_ARG(std::optional>, scale_factor) = std::nullopt; + + /// the upsampling algorithm: one of "nearest", "linear", "bilinear", + /// "bicubic" and "trilinear". Default: "nearest" + typedef std::variant< + enumtype::kNearest, + enumtype::kLinear, + enumtype::kBilinear, + enumtype::kBicubic, + enumtype::kTrilinear> + mode_t; + TORCH_ARG(mode_t, mode) = torch::kNearest; + + /// if "True", the corner pixels of the input and output tensors are + /// aligned, and thus preserving the values at those pixels. This only has + /// effect when :attr:`mode` is "linear", "bilinear", "bicubic", or + /// "trilinear". Default: "False" + TORCH_ARG(std::optional, align_corners) = std::nullopt; +}; + +namespace functional { + +/// Options for `torch::nn::functional::interpolate`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::interpolate(input, +/// F::InterpolateFuncOptions().size(std::vector({4})).mode(torch::kNearest)); +/// ``` +struct TORCH_API InterpolateFuncOptions { + typedef std::variant< + enumtype::kNearest, + enumtype::kLinear, + enumtype::kBilinear, + enumtype::kBicubic, + enumtype::kTrilinear, + enumtype::kArea, + enumtype::kNearestExact> + mode_t; + + /// output spatial sizes. + TORCH_ARG(std::optional>, size) = std::nullopt; + + /// multiplier for spatial size. + TORCH_ARG(std::optional>, scale_factor) = std::nullopt; + + /// the upsampling algorithm: one of "nearest", "linear", "bilinear", + /// "bicubic", "trilinear", "area", "nearest-exact". Default: "nearest" + TORCH_ARG(mode_t, mode) = torch::kNearest; + + /// Geometrically, we consider the pixels of the input and output as squares + /// rather than points. If set to "True", the input and output tensors are + /// aligned by the center points of their corner pixels, preserving the values + /// at the corner pixels. If set to "False", the input and output tensors + /// are aligned by the corner points of their corner pixels, and the + /// interpolation uses edge value padding for out-of-boundary values, making + /// this operation *independent* of input size when `scale_factor` is + /// kept the same. It is *required* when interpolating mode is "linear", + /// "bilinear", "bicubic" or "trilinear". Default: "False" + TORCH_ARG(std::optional, align_corners) = std::nullopt; + + /// recompute the scale_factor for use in the + /// interpolation calculation. When `scale_factor` is passed as a parameter, + /// it is used to compute the `output_size`. If `recompute_scale_factor` is + /// `true` or not specified, a new `scale_factor` will be computed based on + /// the output and input sizes for use in the interpolation computation (i.e. + /// the computation will be identical to if the computed `output_size` were + /// passed-in explicitly). Otherwise, the passed-in `scale_factor` will be + /// used in the interpolation computation. Note that when `scale_factor` is + /// floating-point, the recomputed scale_factor may differ from the one passed + /// in due to rounding and precision issues. + TORCH_ARG(std::optional, recompute_scale_factor) = std::nullopt; + + /// flag to apply anti-aliasing. Using anti-alias + /// option together with :attr:`align_corners` equals "False", interpolation + /// result would match Pillow result for downsampling operation. Supported + /// modes: "bilinear". Default: "False". + TORCH_ARG(bool, antialias) = false; +}; + +} // namespace functional + +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/vision.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/vision.h new file mode 100644 index 0000000000000000000000000000000000000000..189aee3c5fc12f2d9e396e005d8416786ac2771a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/options/vision.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::nn::functional { + +/// Options for `torch::nn::functional::grid_sample`. +/// +/// Example: +/// ``` +/// namespace F = torch::nn::functional; +/// F::grid_sample(input, grid, +/// F::GridSampleFuncOptions().mode(torch::kBilinear).padding_mode(torch::kZeros).align_corners(true)); +/// ``` +struct TORCH_API GridSampleFuncOptions { + typedef std:: + variant + mode_t; + typedef std:: + variant + padding_mode_t; + + /// interpolation mode to calculate output values. Default: Bilinear + TORCH_ARG(mode_t, mode) = torch::kBilinear; + /// padding mode for outside grid values. Default: Zeros + TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros; + /// Specifies perspective to pixel as point. Default: false + TORCH_ARG(std::optional, align_corners) = std::nullopt; +}; + +} // namespace torch::nn::functional diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/parallel/data_parallel.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/parallel/data_parallel.h new file mode 100644 index 0000000000000000000000000000000000000000..58916c861523c74a9db4c48d3d36dec0689f3a75 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/parallel/data_parallel.h @@ -0,0 +1,295 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace torch::nn { + +namespace { + +// Note [Replicating Modules] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Module replication is implemented in the following two steps: +// 1) create a module replica on each destination device using Module.clone(). +// 2) manually add a gradient edge pointing from every parameter X in every +// module replica to the same parameter X in the original module, using +// ReduceAdd as the grad_fn. +// +// ReduceAdd can ONLY be used during the backward pass of data parallel. Forward +// pass cannot use this function as it does not setup gradient function and +// history at all. Do NOT try to use ReduceAdd for any other purposes. +// +// NB: An alternative is to add Broadcast and ReduceAddCoalesce to +// torch/csrc/autograd/functions/comm.cpp as normal autograd functions, +// implement a Replicatable (like cloneable) class and add it as a friend class +// in Module.h. In the forward pass, the Replicatable could use the Broadcast +// function to replicate every module parameter and set gradient functions using +// ReduceAddCoalesce (like how it is implemented in Python). However, unlike in +// Python, where changes to Linear._parameters["weight"] would also apply to +// Linear.weight (using Linear as an example), Linear.weight and +// Linear.parameters_["weight"] are two tensor objects pointing to the same +// TensorImpl. Assigning a new tensor to Linear.parameters_["weight"] will not +// change Linear.weight. To make this work, we will have to: +// 1) force every module to also inherit from Replicatable +// 2) force every module to implement an additional function, e.g., +// Replicatable::load_params(), to pick up changes from parameters_ to their +// own member fields. +// This will be an overkill as Replicatable will only be used in data_parallel, +// not even ddp. + +// Autograd function for the replicate step in data parallel. This is only used +// in data parallel, and should not be exposed as a user API. +struct ReduceAdd : public autograd::Node { + explicit ReduceAdd(const at::Device& destination_device) + : destination_device_(destination_device) {}; + ~ReduceAdd() override = default; + + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + autograd::variable_list apply(autograd::variable_list&& inputs) override { + TORCH_CHECK( + !torch::autograd::compute_requires_grad(inputs), + "ReduceAdd can only be used during the backward pass of data parallel."); + + Tensor output = torch::zeros_like(inputs[0], {destination_device_}); + + for (auto& input : inputs) { + TORCH_CHECK( + input.sizes() == inputs[0].sizes(), + "All inputs of ReduceAdd must have the same size, but got ", + input.sizes(), + " and ", + inputs[0].sizes()); + + TORCH_CHECK( + input.dtype() == inputs[0].dtype(), + "All inputs of ReduceAdd must have the same dtype, but got ", + input.dtype(), + " and ", + inputs[0].dtype()); + + // TODO: use nccl reduce + output.add_(input.to(destination_device_)); + } + + return {output}; + } + + private: + at::Device destination_device_; +}; + +} // namespace + +// A friend function to Module, it recursively sets gradient edges pointing from +// every parameter X in every module replica to the same parameter X in the +// original module. See [Replicating Modules] +template +void replicate_grad_edges( + const std::shared_ptr& module, + const std::vector>& replicas, + const std::vector& devices) { + for (auto& parameter : module->named_parameters(/*recurse=*/false)) { + auto grad_fn = std::make_shared((*parameter).device()); + grad_fn->set_next_edges(autograd::collect_next_edges(*parameter)); + + for (const auto i : c10::irange(devices.size())) { + autograd::set_history(replicas[i]->parameters_[parameter.key()], grad_fn); + } + } + + for (auto& buffer : module->named_buffers(/*recurse=*/false)) { + if (buffer.value().requires_grad()) { + auto grad_fn = std::make_shared((*buffer).device()); + grad_fn->set_next_edges(autograd::collect_next_edges(*buffer)); + + for (const auto i : c10::irange(devices.size())) { + autograd::set_history(replicas[i]->buffers_[buffer.key()], grad_fn); + } + } + } + + for (auto& child : module->children_) { + std::vector> child_replicas; + child_replicas.reserve(devices.size()); + for (auto& replica : replicas) { + child_replicas.push_back(replica->children_[child.key()]); + } + + // recursively set gradient edges for all children + replicate_grad_edges(*child, child_replicas, devices); + } +} + +namespace parallel { + +/// Replicates a module on the given list of devices. +/// A replica is created by calling `clone()` on the module. For this, the +/// module must inherit from `nn::Cloneable`, or define its own `clone()` +/// method, which is expected to perform a deep copy of the module. +template +std::vector> replicate( + const std::shared_ptr& module, + const std::vector& devices) { + std::vector> replicas; + replicas.reserve(devices.size()); + for (const auto& device : devices) { + replicas.push_back( + std::dynamic_pointer_cast(module->clone(device))); + } + // Configure gradient edges to point from replcia parameters to original + // module parameters. See [Replicating Modules] + replicate_grad_edges(module, replicas, devices); + return replicas; +} + +/// Replicates a module holder on the given list of devices. +/// This method allows calling `replicate()` with a module holder, such as +/// `Linear`. +template +std::vector> replicate( + const ModuleHolder& module, + const std::vector& devices) { + auto ptrs = replicate(module.ptr(), devices); + return std::vector>(ptrs.begin(), ptrs.end()); +} + +/// Applies the given inputs to the given modules in a parallel fashion. +/// Conceptually, a thread is spawned for each `(module, input)` pair, in which +/// `forward()` is called on the module with its corresponding input. The +/// outputs of the individual calls are stored in a vector and returned. +/// +/// The first exception caught by any thread is stashed and rethrown after all +/// threads have completed their operation. +/// +/// Further remarks: +/// 1. The length of the module container must match the length of the inputs. +/// 2. If a list of devices is supplied, it must match the list of modules in +/// length. Each device will be set to the current default device during the +/// invocation of the respective module. This means any tensors allocated on the +/// default device inside the module will be constructed on this device. +template +std::vector parallel_apply( + std::vector& modules, + const std::vector& inputs, + const std::optional>& devices = std::nullopt) { + TORCH_CHECK( + modules.size() == inputs.size(), "Must have as many inputs as modules"); + if (devices) { + TORCH_CHECK( + modules.size() == devices->size(), + "Must have as many devices as modules"); + } + + std::vector outputs(modules.size()); + std::mutex mutex; + + // std::exception_ptr can be passed between threads: + // > An instance of std::exception_ptr may be passed to another function, + // > possibly on another thread, where the exception may be rethrown [...]. + // https://en.cppreference.com/w/cpp/error/exception_ptr + std::exception_ptr exception; + + at::parallel_for( + /*begin=*/0, + /*end=*/modules.size(), + /*grain_size=*/1, + [&modules, &inputs, &devices, &outputs, &mutex, &exception]( + int64_t index, int64_t stop) { + for (; index < stop; ++index) { + try { + auto output = modules[index]->forward(inputs[index]); + output = + output.to(devices ? (*devices)[index] : inputs[index].device()); + std::lock_guard lock(mutex); + outputs[index] = output; + } catch (...) { + std::lock_guard lock(mutex); + if (!exception) { + exception = std::current_exception(); + } + } + } + }); + + if (exception) { + std::rethrow_exception(exception); + } + + return outputs; +} + +/// Evaluates `module(input)` in parallel across the given `devices`. If +/// `devices` is not supplied, the invocation is parallelized across all +/// available CUDA devices. If `output_device` is supplied, the final, combined +/// tensor will be placed on this device. If not, it defaults to the first +/// device in `devices`. +/// +/// In detail, this method performs the following four distinct steps: +/// 1. *Scatter* the input to the given devices, +/// 2. *Replicate* (deep clone) the model on each device, +/// 3. *Evaluate* each module with its input on its device, +/// 4. *Gather* the outputs of each replica into a single output tensor, located +/// on the `output_device`. +template +Tensor data_parallel( + ModuleType module, + Tensor input, + std::optional> devices = std::nullopt, + std::optional output_device = std::nullopt, + int64_t dim = 0) { + if (!devices) { + const auto device_count = torch::cuda::device_count(); + TORCH_CHECK( + device_count > 0, "Expected at least one CUDA device to be available"); + devices = std::vector(); + devices->reserve(device_count); + for (const auto index : c10::irange(device_count)) { + devices->emplace_back(kCUDA, static_cast(index)); + } + } + if (!output_device) { + output_device = devices->front(); + } + + if (devices->size() == 1) { + module->to(devices->front()); + input = input.to(devices->front()); + return module->forward(std::move(input)).to(*output_device); + } + + autograd::Scatter scatter(*devices, /*chunk_sizes=*/std::nullopt, dim); + auto scattered_inputs = fmap(scatter.apply({std::move(input)})); + // Input tensor might not be big enough to scale across all available devices + if (scattered_inputs.size() < devices->size()) { + devices->resize( + scattered_inputs.size(), + Device(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)); + } + + auto replicas = replicate(module, *devices); + auto outputs = parallel_apply(replicas, scattered_inputs, *devices); + return autograd::Gather(*output_device, dim) + .apply(fmap(std::move(outputs))) + .front(); +} + +} // namespace parallel +} // namespace torch::nn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/pimpl-inl.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/pimpl-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..cea53b6562bd1d1cec21a7d6ed520fd271ae4c50 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/pimpl-inl.h @@ -0,0 +1,76 @@ +// This class exists only to do SFINAE on abstract types `T` that are really +// `ModuleHolder`, because there's no good way to say that `T` is a +// `ModuleHolder` over some unknown type `ModuleType`. With this, you can do +// `enable_if_t>`. +struct ModuleHolderIndicator {}; + +// A type trait that is true for types that are `ModuleHolder`s. +template +using is_module_holder = + std::is_base_of>; + +template +using disable_if_module_holder_t = + std::enable_if_t::value>; + +// A collection of templates that answer the question whether a type `T` is a +// `ModuleHolder`, and if so whether its contained type is of type `C`. This is +// tricky because it is hard to short circuit in template metaprogramming. A +// naive and incorrect solution to this problem would be something like +// `disable_if::value && typename T::ContainedType == C>`. +// This would disable all types that are not `ModuleHolder`s, because even +// though the `is_module_holder::value` may be `false` for such types the +// `T::ContainedType` access would be ill-formed and thus fail the whole +// expression by the rules of SFINAE. Instead we have to use template +// specialization to statically branch on the first condition +// (`is_module_holder`) and are only then allowed to query +// `T::ContainedType` in the branch for which the condition was true. + +// Base template. +template +struct is_module_holder_of_impl; + +// False branch. `T` is not a `ModuleHolder` and thus not a `ModuleHolder` with +// contained type `C`. +template +struct is_module_holder_of_impl : std::false_type {}; + +// True branch. `T` is a `ModuleHolder` and thus we can legit access its +// `ContainedType` and compare it against `C`. +template +struct is_module_holder_of_impl + : std::is_same {}; + +// Helper template. +template +struct is_module_holder_of : is_module_holder_of_impl< + is_module_holder::value, + std::decay_t, + std::decay_t> {}; + +// A collection of templates that allow deducing the return type of the +// `forward()` method, but only if a module actually has a `forward()` method, +// and otherwise deduces to the type `void`. + +template +struct return_type_of_forward_impl; + +template +struct return_type_of_forward_impl { + using type = decltype(::std::declval().forward(::std::declval()...)); +}; + +template +struct return_type_of_forward_impl { + using type = void; +}; + +template +using return_type_of_forward = return_type_of_forward_impl< + torch::detail::has_forward::value, + C, + Args...>; + +template +using return_type_of_forward_t = + typename return_type_of_forward::type; diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/pimpl.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/pimpl.h new file mode 100644 index 0000000000000000000000000000000000000000..3c1206e4edb82150929d41a57c91902e360321ad --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/pimpl.h @@ -0,0 +1,200 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace torch { +namespace detail { +// Dump all the template metaprogramming in this file. +#include +} // namespace detail + +namespace nn { + +/// A `ModuleHolder` is essentially a wrapper around `std::shared_ptr` where +/// `M` is an `nn::Module` subclass, with convenient constructors defined for +/// the kind of constructions we want to allow for our modules. +template +class ModuleHolder : torch::detail::ModuleHolderIndicator { + protected: + /// The module pointer this class wraps. + /// NOTE: Must be placed at the top of the class so that we can use it with + /// trailing return types below. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::shared_ptr impl_; + + public: + using ContainedType = Contained; + + /// Default constructs the contained module if if has a default constructor, + /// else produces a static error. + /// + /// NOTE: This uses the behavior of template + /// classes in C++ that constructors (or any methods) are only compiled when + /// actually used. + ModuleHolder() : impl_(default_construct()) { + static_assert( + std::is_default_constructible_v, + "You are trying to default construct a module which has " + "no default constructor. Use = nullptr to give it the empty state " + "(e.g. `Linear linear = nullptr;` instead of `Linear linear;`)."); + } + + /// Constructs the `ModuleHolder` with an empty contained value. Access to + /// the underlying module is not permitted and will throw an exception, until + /// a value is assigned. + /* implicit */ ModuleHolder(std::nullptr_t) : impl_(nullptr) {} + + /// Constructs the `ModuleHolder` with a contained module, forwarding all + /// arguments to its constructor. + template < + typename Head, + typename... Tail, + typename = std::enable_if_t< + !(torch::detail::is_module_holder_of::value && + (sizeof...(Tail) == 0))>> + explicit ModuleHolder(Head&& head, Tail&&... tail) + : impl_(new Contained( + std::forward(head), + std::forward(tail)...)) {} + + /// Constructs the `ModuleHolder` from a pointer to the contained type. + /// Example: `Linear(std::make_shared(...))`. + /* implicit */ ModuleHolder(std::shared_ptr module) + : impl_(std::move(module)) {} + + /// Returns true if the `ModuleHolder` contains a module, or false if it is + /// `nullptr`. + explicit operator bool() const noexcept { + return !is_empty(); + } + + /// Forwards to the contained module. + Contained* operator->() { + return get(); + } + + /// Forwards to the contained module. + const Contained* operator->() const { + return get(); + } + + /// Returns a reference to the contained module. + Contained& operator*() { + return *get(); + } + + /// Returns a const reference to the contained module. + const Contained& operator*() const { + return *get(); + } + + /// Returns a shared pointer to the underlying module. + const std::shared_ptr& ptr() const { + TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder"); + return impl_; + } + + /// Returns a pointer to the underlying module. + Contained* get() { + TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder"); + return impl_.get(); + } + + /// Returns a const pointer to the underlying module. + const Contained* get() const { + TORCH_CHECK(!is_empty(), "Accessing empty ModuleHolder"); + return impl_.get(); + } + + /// Calls the `forward()` method of the contained module. + template + auto operator()(Args&&... args) + -> torch::detail::return_type_of_forward_t { + // This will not compile if the module does not have a `forward()` method + // (as expected). + // NOTE: `std::forward` is qualified to prevent VS2017 emitting + // error C2872: 'std': ambiguous symbol + return impl_->forward(::std::forward(args)...); + } + + /// Forwards to the subscript operator of the contained module. + /// NOTE: std::forward is qualified to prevent VS2017 emitting + /// error C2872: 'std': ambiguous symbol + template + decltype(auto) operator[](Arg&& arg) { + return (*impl_)[::std::forward(arg)]; + } + + /// Returns true if the `ModuleHolder` does not contain a module. + bool is_empty() const noexcept { + return impl_ == nullptr; + } + + private: + template + std::shared_ptr default_construct() { + if constexpr (std::is_default_constructible_v) { + return std::make_shared(); + } else { + return nullptr; + } + } +}; + +/// Pretty prints the given `Module` into the `ostream`. +template +std::ostream& operator<<( + std::ostream& stream, + const nn::ModuleHolder& module) { + return stream << *module; +} + +/// Serializes a `ModuleHolder` into an `OutputArchive`. +template +serialize::OutputArchive& operator<<( + serialize::OutputArchive& archive, + const nn::ModuleHolder& module) { + return archive << module.ptr(); +} + +/// Deserializes a `ModuleHolder` from an `InputArchive`. +template +serialize::InputArchive& operator>>( + serialize::InputArchive& archive, + nn::ModuleHolder& module) { + return archive >> module.ptr(); +} + +} // namespace nn +} // namespace torch + +// Workaround for CUDA 10.2 and below not allowing attribute unused on +// using declarations. +#ifdef __CUDACC__ +#define TORCH_UNUSED_EXCEPT_CUDA +#else +#define TORCH_UNUSED_EXCEPT_CUDA [[maybe_unused]] +#endif + +/// Defines a class `Name` which inherits from `nn::ModuleHolder` to provide a +/// wrapper over a `std::shared_ptr`. +/// `Impl` is a type alias for `ImplType` which provides a way to call static +/// method of `ImplType`. +#define TORCH_MODULE_IMPL(Name, ImplType) \ + class Name : public torch::nn::ModuleHolder { /* NOLINT */ \ + public: \ + using torch::nn::ModuleHolder::ModuleHolder; \ + using Impl TORCH_UNUSED_EXCEPT_CUDA = ImplType; \ + } + +/// Like `TORCH_MODULE_IMPL`, but defaults the `ImplType` name to `Impl`. +#define TORCH_MODULE(Name) TORCH_MODULE_IMPL(Name, Name##Impl) diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..8dbfaf5126e4f3db94174937432ea4b017354ab7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils.h @@ -0,0 +1,5 @@ +#pragma once + +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/clip_grad.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/clip_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..a5fbbcbd854cdd8ec9f9358ef9c6328d49cf132f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/clip_grad.h @@ -0,0 +1,144 @@ +#pragma once + +#include + +#include +#include +#include + +namespace torch::nn::utils { + +// Clips gradient norm of a vector of Tensors. +// See +// https://pytorch.org/docs/stable/nn.html?highlight=clip_grad_norm#torch.nn.utils.clip_grad_norm_ +// for more details about this module. +// +// Difference with the python version: unlike the python version, even when +// skipping the finiteness checks (error_if_nonfinite = false), this function +// will introduce a device <=> CPU synchronization (for devices where that makes +// sense!) in order to return a CPU-side `double`. This C++ version therefore +// cannot be run fully asynchronously w.r.t. the device of the gradients. +inline double clip_grad_norm_( + const std::vector& parameters, + double max_norm, + double norm_type = 2.0, + bool error_if_nonfinite = false) { + std::vector params_with_grad; + + for (const auto& param : parameters) { + auto& grad = param.grad(); + if (grad.defined()) { + params_with_grad.push_back(param); + } + } + + if (params_with_grad.empty()) { + return 0.0; + } + + Tensor total_norm_tensor; + if (norm_type == std::numeric_limits::infinity()) { + std::vector norms; + norms.reserve(params_with_grad.size()); + + for (const auto& param : params_with_grad) { + norms.emplace_back(param.grad().data().abs().max()); + } + total_norm_tensor = + (norms.size() == 1) ? norms[0] : torch::max(torch::stack(norms)); + } else if (norm_type == 0) { + total_norm_tensor = + torch::full({}, static_cast(params_with_grad.size())); + } else { + std::vector norms; + norms.reserve(params_with_grad.size()); + + for (const auto& param : params_with_grad) { + norms.emplace_back(param.grad().data().norm(norm_type)); + } + total_norm_tensor = + (norms.size() == 1) ? norms[0] : torch::stack(norms).norm(norm_type); + } + + // When possible (ie when skipping the finiteness check), we avoid + // synchronizing the CPU and the gradients' device until the very end to + // preserve async execution on the device. When checking for finite-ness, this + // optional ensures we only sync once. + std::optional total_norm = std::nullopt; + if (error_if_nonfinite) { + total_norm = total_norm_tensor.item().toDouble(); + TORCH_CHECK( + std::isfinite(*total_norm), + "The total norm of order ", + norm_type, + " for gradients from `parameters` ", + "is non-finite, so it cannot be clipped. To disable this error and scale ", + "the gradients with the non-finite norm anyway, set ", + "`error_if_nonfinite=false`"); + } + + auto clip_coef = max_norm / (total_norm_tensor + 1e-6); + auto clip_coef_clamped = + torch::clamp(clip_coef, std::nullopt /* min */, 1.0 /* max */); + for (auto& param : params_with_grad) { + param.grad().data().mul_(clip_coef_clamped); + } + + if (!total_norm.has_value()) { + total_norm = total_norm_tensor.item().toDouble(); + } + return *total_norm; +} + +// A wrapper around clip_grad_norm_ that allows us to call the function with a +// braced-init-list of Tensors. +inline double clip_grad_norm_( + std::initializer_list parameters, + double max_norm, + double norm_type = 2.0, + bool error_if_nonfinite = false) { + return clip_grad_norm_( + std::vector(parameters), max_norm, norm_type, error_if_nonfinite); +} + +// A wrapper around clip_grad_norm_ that allows us to call the function with a +// single Tensor. +inline double clip_grad_norm_( + Tensor parameter, + double max_norm, + double norm_type = 2.0, + bool error_if_nonfinite = false) { + std::vector params = {std::move(parameter)}; + return clip_grad_norm_(params, max_norm, norm_type, error_if_nonfinite); +} + +// Clips gradient of an iterable of parameters at specified value. +// Gradients are modified in-place. +// See https://pytorch.org/docs/stable/nn.html#clip-grad-value +// for more details about this module. +inline void clip_grad_value_( + const std::vector& parameters, + double clip_value) { + for (const auto& param : parameters) { + if (param.grad().defined()) { + param.grad().data().clamp_(-clip_value, clip_value); + } + } +} + +// A wrapper around clip_grad_value_ that allows us to call the function with a +// braced-init-list of Tensors. +inline void clip_grad_value_( + std::initializer_list parameters, + double clip_value) { + clip_grad_value_(std::vector(parameters), clip_value); +} + +// A wrapper around clip_grad_value_ that allows us to call the function with a +// single Tensor. +inline void clip_grad_value_(Tensor parameter, double clip_value) { + std::vector params = {std::move(parameter)}; + clip_grad_value_(params, clip_value); +} + +} // namespace torch::nn::utils diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/convert_parameters.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/convert_parameters.h new file mode 100644 index 0000000000000000000000000000000000000000..8d0bfdd376a7c71b2996ccb6c3d1fd2466c627cb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/convert_parameters.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include + +namespace torch::nn::utils { + +// This helper function is to check if the parameters are located +// in the same device. Currently, the conversion between model parameters +// and single vector form is not supported for multiple allocations, +// e.g. parameters in different GPUs, or mixture of CPU/GPU. +inline std::optional _check_param_device( + const torch::Tensor& param, + std::optional old_param_device) { + // Meet the first parameter + if (old_param_device == std::nullopt) { + old_param_device = param.is_cuda() ? param.get_device() : -1; + } else { + bool warn = false; + if (param.is_cuda()) { // Check if in same GPU + warn = (param.get_device() != old_param_device); + } else { // Check if in CPU + warn = (old_param_device != -1); + } + if (warn) { + TORCH_CHECK( + false, + "Found two parameters on different devices, ", + "this is currently not supported."); + } + } + + return old_param_device; +} + +// Convert parameters to one vector +inline torch::Tensor parameters_to_vector( + const std::vector& parameters) { + std::optional param_device; + + std::vector vec; + vec.reserve(parameters.size()); + + for (const torch::Tensor& param : parameters) { + // Ensure the parameters are located in the same device + param_device = _check_param_device(param, param_device); + + vec.push_back(param.view(-1)); + } + + return torch::cat(vec); +} + +// Convert one vector to the parameters +inline void vector_to_parameters( + const torch::Tensor& vec, + const std::vector& parameters) { + // Flag for the device where the parameter is located + std::optional param_device; + + // Pointer for slicing the vector for each parameter + int64_t pointer = 0; + for (const torch::Tensor& param : parameters) { + // Ensure the parameters are located in the same device + param_device = _check_param_device(param, param_device); + + // The length of the parameter + auto num_param = param.numel(); + // Slice the vector, reshape it, and replace the old data of the parameter + param.set_data( + vec.slice(0, pointer, pointer + num_param).view_as(param).data()); + + // Increment the pointer + pointer += num_param; + } +} + +} // namespace torch::nn::utils diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/rnn.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/rnn.h new file mode 100644 index 0000000000000000000000000000000000000000..ab9a9b97c32c563a29b366dbe1feb11164dfacf1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/nn/utils/rnn.h @@ -0,0 +1,348 @@ +#pragma once + +#include +#include + +#include + +namespace torch::nn::utils::rnn { + +inline Tensor invert_permutation(const Tensor& permutation) { + if (!permutation.defined()) { + return torch::Tensor(); + } + Tensor output = + torch::empty_like(permutation, torch::MemoryFormat::Contiguous); + output.scatter_( + 0, + permutation, + torch::arange(0, permutation.numel(), permutation.device())); + return output; +} + +/// Holds the data and list of `batch_sizes` of a packed sequence. +/// +/// All RNN modules accept packed sequences as inputs. +/// +/// Note: +/// Instances of this class should never be created manually. They are meant +/// to be instantiated by functions like `pack_padded_sequence`. +/// +/// Batch sizes represent the number elements at each sequence step in +/// the batch, not the varying sequence lengths passed to +/// `pack_padded_sequence`. For instance, given data ``abc`` and ``x`` +/// the :class:`PackedSequence` would contain data ``axbc`` with +/// ``batch_sizes=[2,1,1]``. +/// +/// Attributes: +/// data (Tensor): Tensor containing packed sequence +/// batch_sizes (Tensor): Tensor of integers holding +/// information about the batch size at each sequence step +/// sorted_indices (Tensor, optional): Tensor of integers holding how this +/// :class:`PackedSequence` is constructed from sequences. +/// unsorted_indices (Tensor, optional): Tensor of integers holding how this +/// to recover the original sequences with correct order. +/// +/// .. note:: +/// `data` can be on arbitrary device and of arbitrary dtype. +/// `sorted_indices` and `unsorted_indices` must be ``torch::kInt64`` +/// tensors on the same device as `data`. +/// +/// However, `batch_sizes` should always be a CPU ``torch::kInt64`` tensor. +/// +/// This invariant is maintained throughout `PackedSequence` class, +/// and all functions that construct a `PackedSequence` in libtorch +/// (i.e., they only pass in tensors conforming to this constraint). +class PackedSequence { + public: + explicit PackedSequence( + Tensor data, + Tensor batch_sizes, + Tensor sorted_indices = {}, + Tensor unsorted_indices = {}) { + // NB: if unsorted_indices is provided, it should be the inverse permutation + // to sorted_indices. Don't assert it here because the PackedSequence ctor + // should only be used internally. + if (!unsorted_indices.defined()) { + unsorted_indices = invert_permutation(sorted_indices); + } + TORCH_CHECK( + batch_sizes.device().type() == kCPU, + "batch_sizes should always be on CPU. " + "Instances of PackedSequence should never be created manually. " + "They should be instantiated by functions like pack_sequence " + "and pack_padded_sequences in nn::utils::rnn. " + "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence"); + data_ = std::move(data); + batch_sizes_ = std::move(batch_sizes); + sorted_indices_ = std::move(sorted_indices); + unsorted_indices_ = std::move(unsorted_indices); + } + + const Tensor& data() const { + return data_; + } + + const Tensor& batch_sizes() const { + return batch_sizes_; + } + + const Tensor& sorted_indices() const { + return sorted_indices_; + } + + const Tensor& unsorted_indices() const { + return unsorted_indices_; + } + + PackedSequence pin_memory() const { + // Why not convert `batch_sizes`? + // See NOTE [ device and dtype of a PackedSequence ] + return PackedSequence( + data_.pin_memory(), + batch_sizes_, + sorted_indices_.defined() ? sorted_indices_.pin_memory() : Tensor(), + unsorted_indices_.defined() ? unsorted_indices_.pin_memory() + : Tensor()); + } + + PackedSequence to(TensorOptions options) const { + // Performs dtype and/or device conversion on `data_`. + // + // If the ``data_`` Tensor already has the correct `torch::Dtype` + // and `torch::Device`, then ``self`` is returned. + // Otherwise, returns a copy with the desired configuration. + + // Why not convert `batch_sizes`? + // See NOTE [ device and dtype of a PackedSequence ] + Tensor data = data_.to(options); + if (data.is_same(data_)) { + return *this; + } else { + // Does not forward device or dtype args, device is set from data.device() + Tensor sorted_indices = sorted_indices_.defined() + ? sorted_indices_.to( + options.device(data.device()).dtype(sorted_indices_.dtype())) + : Tensor(); + Tensor unsorted_indices = unsorted_indices_.defined() + ? unsorted_indices_.to( + options.device(data.device()).dtype(unsorted_indices_.dtype())) + : Tensor(); + return PackedSequence( + std::move(data), + batch_sizes_, + std::move(sorted_indices), + std::move(unsorted_indices)); + } + } + + PackedSequence cuda() const { + return to(kCUDA); + } + + PackedSequence cpu() const { + return to(kCPU); + } + + /// Returns true if `data_` stored on a gpu + bool is_cuda() const { + return data_.is_cuda(); + } + + /// Returns true if `data_` stored on in pinned memory + bool is_pinned() const { + return data_.is_pinned(); + } + + private: + Tensor data_; + Tensor batch_sizes_; + Tensor sorted_indices_; + Tensor unsorted_indices_; +}; + +/// Packs a Tensor containing padded sequences of variable length. +/// +/// `input` can be of size ``T x B x *`` where `T` is the length of the +/// longest sequence (equal to ``lengths[0]``), ``B`` is the batch size, and +/// ``*`` is any number of dimensions (including 0). If ``batch_first`` is +/// ``true``, ``B x T x *`` `input` is expected. +/// +/// For unsorted sequences, use `enforce_sorted = false`. If `enforce_sorted` is +/// ``true``, the sequences should be sorted by length in a decreasing order, +/// i.e. +/// ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the +/// shortest one. +/// +/// Note: +/// This function accepts any input that has at least two dimensions. You +/// can apply it to pack the labels, and use the output of the RNN with +/// them to compute the loss directly. A Tensor can be retrieved from +/// a `PackedSequence` object by calling its ``.data()`` function. +/// +/// Arguments: +/// input (Tensor): padded batch of variable length sequences. +/// lengths (Tensor): list of sequences lengths of each batch element. +/// batch_first (bool, optional): if ``true``, the input is expected in ``B +/// x T x *`` +/// format. Default: ``false``. +/// enforce_sorted (bool, optional): if ``true``, the input is expected to +/// contain sequences sorted by length in a decreasing order. If +/// ``false``, this condition is not checked. Default: ``true``. +/// +/// Returns: +/// a `PackedSequence` object +inline PackedSequence pack_padded_sequence( + Tensor input, + Tensor lengths, + bool batch_first = false, + bool enforce_sorted = true) { + lengths = lengths.to(kInt64); + Tensor sorted_indices; + if (enforce_sorted) { + sorted_indices = Tensor(); + } else { + std::tie(lengths, sorted_indices) = + torch::sort(lengths, /*dim=*/-1, /*descending=*/true); + sorted_indices = sorted_indices.to(input.device()); + int64_t batch_dim = batch_first ? 0 : 1; + input = input.index_select(batch_dim, sorted_indices); + } + + auto [data, batch_sizes] = + torch::_pack_padded_sequence(input, lengths, batch_first); + return PackedSequence( + std::move(data), std::move(batch_sizes), std::move(sorted_indices), {}); +} + +/// Pads a packed batch of variable length sequences. +/// +/// It is an inverse operation to `pack_padded_sequence`. +/// +/// The returned Tensor's data will be of size ``T x B x *``, where `T` is the +/// length of the longest sequence and `B` is the batch size. If ``batch_first`` +/// is true, the data will be transposed into ``B x T x *`` format. +/// +/// Batch elements will be ordered decreasingly by their length. +/// +/// Arguments: +/// sequence (PackedSequence): batch to pad +/// batch_first (bool, optional): if ``true``, the output will be in ``B x T +/// x *`` +/// format. +/// padding_value (double, optional): values for padded elements. +/// total_length (int64_t, optional): if specified, the output will be +/// padded to +/// have length `total_length`. This method will throw error +/// if `total_length` is less than the max sequence length in +/// `sequence`. +/// +/// Returns: +/// Tuple of Tensor containing the padded sequence, and a Tensor +/// containing the list of lengths of each sequence in the batch. +inline std::tuple pad_packed_sequence( + const PackedSequence& sequence, + bool batch_first = false, + double padding_value = 0.0, + std::optional total_length = std::nullopt) { + int64_t max_seq_length = sequence.batch_sizes().size(0); + if (total_length.has_value()) { + int64_t total_length_val = total_length.value(); + TORCH_CHECK( + total_length_val >= max_seq_length, + "Expected total_length to be at least the length " + "of the longest sequence in input, but got " + "total_length=", + total_length_val, + " and max sequence length being ", + max_seq_length); + max_seq_length = total_length_val; + } + auto [padded_output, lengths] = torch::_pad_packed_sequence( + sequence.data(), + sequence.batch_sizes(), + batch_first, + padding_value, + max_seq_length); + const Tensor& unsorted_indices = sequence.unsorted_indices(); + if (unsorted_indices.defined()) { + int64_t batch_dim = batch_first ? 0 : 1; + return std::make_tuple( + padded_output.index_select(batch_dim, unsorted_indices), + lengths.index({unsorted_indices.cpu()})); + } + return std::make_tuple(padded_output, lengths); +} + +/// Pad a list of variable length Tensors with ``padding_value`` +/// +/// ``pad_sequence`` stacks a list of Tensors along a new dimension, +/// and pads them to equal length. For example, if the input is list of +/// sequences with size ``L x *`` and if batch_first is false, and ``T x B x *`` +/// otherwise. +/// +/// `B` is batch size. It is equal to the number of elements in ``sequences``. +/// `T` is length of the longest sequence. +/// `L` is length of the sequence. +/// `*` is any number of trailing dimensions, including none. +/// +/// Note: +/// This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` +/// where `T` is the length of the longest sequence. This function assumes +/// trailing dimensions and type of all the Tensors in sequences are same. +/// +/// Arguments: +/// sequences (torch::ArrayRef): list of variable length sequences. +/// batch_first (bool, optional): output will be in ``B x T x *`` if true, +/// or in +/// ``T x B x *`` otherwise +/// padding_value (double, optional): value for padded elements. Default: 0. +/// padding_side (str, optional): the side to pad the sequences on. Default: +/// "right". +/// +/// Returns: +/// Tensor of size ``T x B x *`` if `batch_first` is ``false``. +/// Tensor of size ``B x T x *`` otherwise +inline Tensor pad_sequence( + ArrayRef sequences, + bool batch_first = false, + double padding_value = 0, + std::string_view padding_side = "right") { + return at::pad_sequence(sequences, batch_first, padding_value, padding_side); +} + +/// Packs a list of variable length Tensors +/// +/// ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is +/// the length of a sequence and `*` is any number of trailing dimensions, +/// including zero. +/// +/// For unsorted sequences, use `enforce_sorted = false`. If ``enforce_sorted`` +/// is ``true``, the sequences should be sorted in the order of decreasing +/// length. +/// +/// +/// Arguments: +/// sequences (torch::ArrayRef): A list of sequences of decreasing +/// length. enforce_sorted (bool, optional): if ``true``, checks that the +/// input +/// contains sequences sorted by length in a decreasing order. If +/// ``false``, this condition is not checked. Default: ``true``. +/// +/// Returns: +/// a `PackedSequence` object +inline PackedSequence pack_sequence( + ArrayRef sequences, + bool enforce_sorted = true) { + Tensor lengths = torch::empty({(int64_t)sequences.size()}, kInt64); + for (const auto i : c10::irange(sequences.size())) { + lengths[static_cast(i)] = sequences[i].size(0); + } + return pack_padded_sequence( + at::pad_sequence(sequences), + std::move(lengths), + /*batch_first=*/false, + /*enforce_sorted=*/enforce_sorted); +} + +} // namespace torch::nn::utils::rnn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim.h new file mode 100644 index 0000000000000000000000000000000000000000..d52b34c076dd095ba8d0c6e4820001def3d24d0c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/adagrad.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/adagrad.h new file mode 100644 index 0000000000000000000000000000000000000000..1ad6186faeabbc0ea84c3595c817b1a0157a28e6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/adagrad.h @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace torch::serialize { +class OutputArchive; +class InputArchive; +} // namespace torch::serialize + +namespace torch::optim { + +struct TORCH_API AdagradOptions + : public OptimizerCloneableOptions { + AdagradOptions(double lr = 1e-2); + TORCH_ARG(double, lr) = 1e-2; + TORCH_ARG(double, lr_decay) = 0; + TORCH_ARG(double, weight_decay) = 0; + TORCH_ARG(double, initial_accumulator_value) = 0; + TORCH_ARG(double, eps) = 1e-10; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const AdagradOptions& lhs, + const AdagradOptions& rhs); + double get_lr() const override; + void set_lr(const double lr) override; +}; + +struct TORCH_API AdagradParamState + : public OptimizerCloneableParamState { + TORCH_ARG(torch::Tensor, sum); + TORCH_ARG(int64_t, step) = 0; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const AdagradParamState& lhs, + const AdagradParamState& rhs); +}; + +class TORCH_API Adagrad : public Optimizer { + public: + explicit Adagrad( + const std::vector& param_groups, + AdagradOptions defaults = {}) + : Optimizer(param_groups, std::make_unique(defaults)) { + TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); + TORCH_CHECK( + defaults.lr_decay() >= 0, + "Invalid lr_decay value: ", + defaults.lr_decay()); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + TORCH_CHECK( + defaults.initial_accumulator_value() >= 0, + "Invalid initial_accumulator_value value: ", + defaults.initial_accumulator_value()); + TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); + + for (const auto& group : param_groups_) { + for (const auto& p : group.params()) { + auto state = std::make_unique(); + state->step(0); + state->sum(torch::full_like( + p.data(), + defaults.initial_accumulator_value(), + at::MemoryFormat::Preserve)); + state_[p.unsafeGetTensorImpl()] = std::move(state); + } + } + } + + explicit Adagrad(std::vector params, AdagradOptions defaults = {}) + : Adagrad({OptimizerParamGroup(std::move(params))}, std::move(defaults)) { + } + + torch::Tensor step(LossClosure closure = nullptr) override; + void save(serialize::OutputArchive& archive) const override; + void load(serialize::InputArchive& archive) override; + + private: + template + static void serialize(Self& self, Archive& archive) { + _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(Adagrad); + } +}; +} // namespace torch::optim diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/adam.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/adam.h new file mode 100644 index 0000000000000000000000000000000000000000..6c06e4030cf4cbd43684481e24aad002444e7797 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/adam.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::serialize { +class OutputArchive; +class InputArchive; +} // namespace torch::serialize + +namespace torch::optim { + +struct TORCH_API AdamOptions : public OptimizerCloneableOptions { + AdamOptions(double lr = 1e-3); + TORCH_ARG(double, lr) = 1e-3; + typedef std::tuple betas_t; + TORCH_ARG(betas_t, betas) = std::make_tuple(0.9, 0.999); + TORCH_ARG(double, eps) = 1e-8; + TORCH_ARG(double, weight_decay) = 0; + TORCH_ARG(bool, amsgrad) = false; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const AdamOptions& lhs, + const AdamOptions& rhs); + double get_lr() const override; + void set_lr(const double lr) override; +}; + +struct TORCH_API AdamParamState + : public OptimizerCloneableParamState { + TORCH_ARG(int64_t, step) = 0; + TORCH_ARG(torch::Tensor, exp_avg); + TORCH_ARG(torch::Tensor, exp_avg_sq); + TORCH_ARG(torch::Tensor, max_exp_avg_sq) = {}; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const AdamParamState& lhs, + const AdamParamState& rhs); +}; + +class TORCH_API Adam : public Optimizer { + public: + explicit Adam( + const std::vector& param_groups, + AdamOptions defaults = {}) + : Optimizer(param_groups, std::make_unique(defaults)) { + TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); + TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); + auto betas = defaults.betas(); + TORCH_CHECK( + 0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0, + "Invalid beta parameter at index 0: ", + std::get<0>(betas)); + TORCH_CHECK( + 0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0, + "Invalid beta parameter at index 1: ", + std::get<1>(betas)); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + } + explicit Adam(std::vector params, AdamOptions defaults = {}) + : Adam({OptimizerParamGroup(std::move(params))}, std::move(defaults)) {} + + torch::Tensor step(LossClosure closure = nullptr) override; + void save(serialize::OutputArchive& archive) const override; + void load(serialize::InputArchive& archive) override; + + private: + template + static void serialize(Self& self, Archive& archive) { + _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(Adam); + } +}; +} // namespace torch::optim diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/adamw.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/adamw.h new file mode 100644 index 0000000000000000000000000000000000000000..d656921a719d0cfab6c992bab7f73f509c2c8b97 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/adamw.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::serialize { +class OutputArchive; +class InputArchive; +} // namespace torch::serialize + +namespace torch::optim { + +struct TORCH_API AdamWOptions : public OptimizerCloneableOptions { + AdamWOptions(double lr = 1e-3); + TORCH_ARG(double, lr) = 1e-3; + typedef std::tuple betas_t; + TORCH_ARG(betas_t, betas) = std::make_tuple(0.9, 0.999); + TORCH_ARG(double, eps) = 1e-8; + TORCH_ARG(double, weight_decay) = 1e-2; + TORCH_ARG(bool, amsgrad) = false; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const AdamWOptions& lhs, + const AdamWOptions& rhs); + double get_lr() const override; + void set_lr(const double lr) override; +}; + +struct TORCH_API AdamWParamState + : public OptimizerCloneableParamState { + TORCH_ARG(int64_t, step) = 0; + TORCH_ARG(torch::Tensor, exp_avg); + TORCH_ARG(torch::Tensor, exp_avg_sq); + TORCH_ARG(torch::Tensor, max_exp_avg_sq) = {}; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const AdamWParamState& lhs, + const AdamWParamState& rhs); +}; + +class TORCH_API AdamW : public Optimizer { + public: + explicit AdamW( + const std::vector& param_groups, + AdamWOptions defaults = {}) + : Optimizer(param_groups, std::make_unique(defaults)) { + TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); + TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); + auto betas = defaults.betas(); + TORCH_CHECK( + 0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0, + "Invalid beta parameter at index 0: ", + std::get<0>(betas)); + TORCH_CHECK( + 0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0, + "Invalid beta parameter at index 1: ", + std::get<1>(betas)); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + } + explicit AdamW(std::vector params, AdamWOptions defaults = {}) + : AdamW({OptimizerParamGroup(std::move(params))}, std::move(defaults)) {} + + torch::Tensor step(LossClosure closure = nullptr) override; + void save(serialize::OutputArchive& archive) const override; + void load(serialize::InputArchive& archive) override; + + private: + template + static void serialize(Self& self, Archive& archive) { + _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(AdamW); + } +}; +} // namespace torch::optim diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/lbfgs.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/lbfgs.h new file mode 100644 index 0000000000000000000000000000000000000000..3d5f1832cf600d2705687efc5f295ff07e64e0a2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/lbfgs.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace torch::optim { + +struct TORCH_API LBFGSOptions : public OptimizerCloneableOptions { + LBFGSOptions(double lr = 1); + TORCH_ARG(double, lr) = 1; + TORCH_ARG(int64_t, max_iter) = 20; + TORCH_ARG(std::optional, max_eval) = std::nullopt; + TORCH_ARG(double, tolerance_grad) = 1e-7; + TORCH_ARG(double, tolerance_change) = 1e-9; + TORCH_ARG(int64_t, history_size) = 100; + TORCH_ARG(std::optional, line_search_fn) = std::nullopt; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const LBFGSOptions& lhs, + const LBFGSOptions& rhs); + double get_lr() const override; + void set_lr(const double lr) override; +}; + +struct TORCH_API LBFGSParamState + : public OptimizerCloneableParamState { + TORCH_ARG(int64_t, func_evals) = 0; + TORCH_ARG(int64_t, n_iter) = 0; + TORCH_ARG(double, t) = 0; + TORCH_ARG(double, prev_loss) = 0; + TORCH_ARG(Tensor, d) = {}; + TORCH_ARG(Tensor, H_diag) = {}; + TORCH_ARG(Tensor, prev_flat_grad) = {}; + TORCH_ARG(std::deque, old_dirs); + TORCH_ARG(std::deque, old_stps); + TORCH_ARG(std::deque, ro); + TORCH_ARG(std::optional>, al) = std::nullopt; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const LBFGSParamState& lhs, + const LBFGSParamState& rhs); +}; + +class TORCH_API LBFGS : public Optimizer { + public: + explicit LBFGS( + const std::vector& param_groups, + LBFGSOptions defaults = {}) + : Optimizer(param_groups, std::make_unique(defaults)) { + TORCH_CHECK( + param_groups_.size() == 1, + "LBFGS doesn't support per-parameter options (parameter groups)"); + if (defaults.max_eval() == std::nullopt) { + auto max_eval_val = (defaults.max_iter() * 5) / 4; + static_cast(param_groups_[0].options()) + .max_eval(max_eval_val); + static_cast(*defaults_).max_eval(max_eval_val); + } + _numel_cache = std::nullopt; + } + explicit LBFGS(std::vector params, LBFGSOptions defaults = {}) + : LBFGS({OptimizerParamGroup(std::move(params))}, std::move(defaults)) {} + + Tensor step(LossClosure closure) override; + void save(serialize::OutputArchive& archive) const override; + void load(serialize::InputArchive& archive) override; + + private: + std::optional _numel_cache; + int64_t _numel(); + Tensor _gather_flat_grad(); + void _add_grad(const double step_size, const Tensor& update); + std::tuple _directional_evaluate( + const LossClosure& closure, + const std::vector& x, + double t, + const Tensor& d); + void _set_param(const std::vector& params_data); + std::vector _clone_param(); + + template + static void serialize(Self& self, Archive& archive) { + _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(LBFGS); + } +}; +} // namespace torch::optim diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/optimizer.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..e2af5b74ee599a7cc576b4d482dedcf10f8ec409 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/optimizer.h @@ -0,0 +1,223 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +// Forward declarations confuse Doxygen +#ifndef DOXYGEN_SHOULD_SKIP_THIS +namespace at { +class Tensor; +} // namespace at + +namespace torch { +using at::Tensor; +namespace serialize { +class OutputArchive; +class InputArchive; +} // namespace serialize +} // namespace torch +#endif // DOXYGEN_SHOULD_SKIP_THIS + +namespace torch::optim { + +class TORCH_API OptimizerParamState { + public: + OptimizerParamState() = default; + OptimizerParamState(const OptimizerParamState&) = default; + OptimizerParamState& operator=(const OptimizerParamState&) = default; + OptimizerParamState(OptimizerParamState&&) noexcept = default; + OptimizerParamState& operator=(OptimizerParamState&&) noexcept = default; + virtual std::unique_ptr clone() const; + virtual void serialize(torch::serialize::InputArchive& archive); + virtual void serialize(torch::serialize::OutputArchive& archive) const; + virtual ~OptimizerParamState() = default; +}; + +template +class OptimizerCloneableParamState : public OptimizerParamState { + std::unique_ptr clone() const override { + return std::make_unique(static_cast(*this)); + } +}; + +class TORCH_API OptimizerOptions { + public: + OptimizerOptions() = default; + OptimizerOptions(const OptimizerOptions&) = default; + OptimizerOptions& operator=(const OptimizerOptions&) = default; + OptimizerOptions(OptimizerOptions&&) noexcept = default; + OptimizerOptions& operator=(OptimizerOptions&&) noexcept = default; + virtual std::unique_ptr clone() const; + virtual void serialize(torch::serialize::InputArchive& archive); + virtual void serialize(torch::serialize::OutputArchive& archive) const; + virtual ~OptimizerOptions() = default; + virtual double get_lr() const; + virtual void set_lr(const double lr); +}; + +template +class OptimizerCloneableOptions : public OptimizerOptions { + private: + std::unique_ptr clone() const override { + return std::make_unique(static_cast(*this)); + } +}; + +/// Stores parameters in the param_group and stores a pointer to the +/// OptimizerOptions +class TORCH_API OptimizerParamGroup { + public: + // NOTE: In order to store `OptimizerParamGroup` in a `std::vector`, it has to + // be copy-constructible. + OptimizerParamGroup(const OptimizerParamGroup& param_group) + : params_(param_group.params()), + options_( + param_group.has_options() ? param_group.options().clone() + : nullptr) {} + OptimizerParamGroup(OptimizerParamGroup&& param_group) = default; + OptimizerParamGroup(std::vector params) + : params_(std::move(params)) {} + OptimizerParamGroup( + std::vector params, + std::unique_ptr options) + : params_(std::move(params)), options_(std::move(options)) {} + + OptimizerParamGroup& operator=(const OptimizerParamGroup& param_group) = + delete; + OptimizerParamGroup& operator=(OptimizerParamGroup&& param_group) noexcept = + default; + ~OptimizerParamGroup() = default; + bool has_options() const; + OptimizerOptions& options(); + const OptimizerOptions& options() const; + void set_options(std::unique_ptr options); + std::vector& params(); + const std::vector& params() const; + + protected: + std::vector params_; + std::unique_ptr options_; +}; + +class TORCH_API Optimizer { + public: + // The copy constructor is deleted, because the user should use the + // `state_dict` / `load_state_dict` API to copy an optimizer instead. + Optimizer(const Optimizer& optimizer) = delete; + Optimizer(Optimizer&& optimizer) = default; + Optimizer& operator=(const Optimizer& optimizer) = delete; + Optimizer& operator=(Optimizer&& optimizer) = default; + + explicit Optimizer( + const std::vector& param_groups, + std::unique_ptr defaults) + : defaults_(std::move(defaults)) { + for (const auto& param_group : param_groups) { + add_param_group(param_group); + } + } + + /// Constructs the `Optimizer` from a vector of parameters. + explicit Optimizer( + std::vector parameters, + std::unique_ptr defaults) + : Optimizer( + {OptimizerParamGroup(std::move(parameters))}, + std::move(defaults)) {} + + /// Adds the given param_group to the optimizer's param_group list. + void add_param_group(const OptimizerParamGroup& param_group); + + virtual ~Optimizer() = default; + + using LossClosure = std::function; + /// A loss function closure, which is expected to return the loss value. + virtual Tensor step(LossClosure closure = nullptr) = 0; + + /// Adds the given vector of parameters to the optimizer's parameter list. + void add_parameters(const std::vector& parameters); + + /// Zeros out the gradients of all parameters. + void zero_grad(bool set_to_none = true); + + /// Provides a const reference to the parameters in the first param_group this + /// optimizer holds. + const std::vector& parameters() const noexcept; + + /// Provides a reference to the parameters in the first param_group this + /// optimizer holds. + std::vector& parameters() noexcept; + + /// Returns the number of parameters referenced by the optimizer. + size_t size() const noexcept; + + OptimizerOptions& defaults() noexcept; + + const OptimizerOptions& defaults() const noexcept; + + /// Provides a reference to the param_groups this optimizer holds. + std::vector& param_groups() noexcept; + + /// Provides a const reference to the param_groups this optimizer holds. + const std::vector& param_groups() const noexcept; + + /// Provides a reference to the state this optimizer holds + ska::flat_hash_map>& + state() noexcept; + + /// Provides a const reference to the state this optimizer holds + const ska::flat_hash_map>& state() + const noexcept; + + /// Serializes the optimizer state into the given `archive`. + virtual void save(serialize::OutputArchive& archive) const; + + /// Deserializes the optimizer state from the given `archive`. + virtual void load(serialize::InputArchive& archive); + + protected: + std::vector param_groups_; + ska::flat_hash_map> state_; + std::unique_ptr defaults_; +}; + +/* How do we decide whether to serialize undefined tensors or + std::nullopt values into the output archive? +Answer: we strictly follow the behavior of Python API. To be more specific: + +For optimizer options: +a) For undefined tensor: currently no tensor is used as an options argument in +Python API, so we don't need to worry about it now. b) For std::nullopt value: +we serialize std::nullopt values into the output archive, to follow the exact +same behavior as Python API. + +For optimizer param state: +a) For undefined tensor: in param state, undefined tensor in C++ impl is +equivalent to missing key in Python impl. Since we don't serialize missing keys +in Python API, we skip undefined tensors when serializing the param state. b) +For std::nullopt value: in param state, std::nullopt value in C++ impl is +equivalent to missing key in Python impl. Since we don't serialize missing keys +in Python API, we skip std::nullopt values when serializing the param state. */ + +/// Serializes an `Optimizer` into an `OutputArchive`. +TORCH_API serialize::OutputArchive& operator<<( + serialize::OutputArchive& archive, + const Optimizer& optimizer); + +/// Deserializes a `Tensor` from an `InputArchive`. +TORCH_API serialize::InputArchive& operator>>( + serialize::InputArchive& archive, + Optimizer& optimizer); + +} // namespace torch::optim diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/rmsprop.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/rmsprop.h new file mode 100644 index 0000000000000000000000000000000000000000..7b6b9dea5649f5cdff6adcc3128b135580bdb4db --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/rmsprop.h @@ -0,0 +1,91 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace torch::serialize { +class OutputArchive; +class InputArchive; +} // namespace torch::serialize + +namespace torch::optim { + +struct TORCH_API RMSpropOptions + : public OptimizerCloneableOptions { + RMSpropOptions(double lr = 1e-2); + TORCH_ARG(double, lr) = 1e-2; + TORCH_ARG(double, alpha) = 0.99; + TORCH_ARG(double, eps) = 1e-8; + TORCH_ARG(double, weight_decay) = 0; + TORCH_ARG(double, momentum) = 0; + TORCH_ARG(bool, centered) = false; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const RMSpropOptions& lhs, + const RMSpropOptions& rhs); + double get_lr() const override; + void set_lr(const double lr) override; +}; + +struct TORCH_API RMSpropParamState + : public OptimizerCloneableParamState { + TORCH_ARG(int64_t, step) = 0; + TORCH_ARG(torch::Tensor, square_avg); + TORCH_ARG(torch::Tensor, momentum_buffer) = {}; + TORCH_ARG(torch::Tensor, grad_avg) = {}; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const RMSpropParamState& lhs, + const RMSpropParamState& rhs); +}; + +class TORCH_API RMSprop : public Optimizer { + public: + explicit RMSprop( + const std::vector& param_groups, + RMSpropOptions defaults = {}) + : Optimizer(param_groups, std::make_unique(defaults)) { + TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); + TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); + TORCH_CHECK( + defaults.momentum() >= 0, + "Invalid momentum value: ", + defaults.momentum()); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + TORCH_CHECK( + defaults.alpha() >= 0, "Invalid alpha value: ", defaults.alpha()); + } + + explicit RMSprop(std::vector params, RMSpropOptions defaults = {}) + : RMSprop({OptimizerParamGroup(std::move(params))}, std::move(defaults)) { + } + + torch::Tensor step(LossClosure closure = nullptr) override; + void save(serialize::OutputArchive& archive) const override; + void load(serialize::InputArchive& archive) override; + + private: + template + static void serialize(Self& self, Archive& archive) { + _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(RMSprop); + } +}; +} // namespace torch::optim diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..4fe9181c499200a29fb2eb4afa19fae8179184bf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +#include + +namespace torch::optim { + +class TORCH_API LRScheduler { + public: + // This class needs to take a reference of an optimizer from outside such that + // it can modify its learning rates; due to this the lifetime of said + // optimizer must be maintained + LRScheduler(torch::optim::Optimizer& optimizer); + + virtual ~LRScheduler() = default; + + void step(); + + protected: + // A vector of learning rates is calculated and returned from the specific + // subclass. A vector is returned with each element being a separate learning + // rate for each param group - although the normal use case would be to return + // a vector of identical elements. + virtual std::vector get_lrs() = 0; + + // Get current learning rates from the optimizer + std::vector get_current_lrs() const; + + unsigned step_count_{}; + + private: + void set_optimizer_lrs(const std::vector& learning_rates); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + torch::optim::Optimizer& optimizer_; +}; +} // namespace torch::optim diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/reduce_on_plateau_scheduler.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/reduce_on_plateau_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..17c89816d79d3ab3a73b54614d9f248874de2f56 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/reduce_on_plateau_scheduler.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include + +#include + +#include + +namespace torch::optim { + +class TORCH_API ReduceLROnPlateauScheduler { + public: + enum SchedulerMode { min, max }; + enum ThresholdMode { rel, abs }; + ReduceLROnPlateauScheduler( + Optimizer& optimizer, + SchedulerMode mode = min, + float factor = 0.1, + int patience = 10, + double threshold = 1e-4, + ThresholdMode threshold_mode = rel, + int cooldown = 0, + const std::vector& min_lr = std::vector(), + double eps = 1e-8, + bool verbose = false); + + virtual ~ReduceLROnPlateauScheduler() = default; + + void step(float metric); + + private: + void reset(); + void reduce_lr(int epoch); + bool in_cooldown() const; + bool is_better(float a); + void init_is_better( + SchedulerMode mode, + double threshold, + ThresholdMode threshold_mode); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + Optimizer& optimizer; + SchedulerMode mode{}; + float mode_worse{}; + float factor; + int patience; + double threshold{}; + ThresholdMode threshold_mode{}; + int cooldown{}; + int cooldown_counter{}; + std::vector min_lrs; + double eps; + float best{}; + bool verbose; + int last_epoch{}; + int num_bad_epochs{}; +}; +} // namespace torch::optim diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/step_lr.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/step_lr.h new file mode 100644 index 0000000000000000000000000000000000000000..f46b274f518bd5f116c3616341e8e1a9c070715c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/schedulers/step_lr.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace torch::optim { + +class TORCH_API StepLR : public LRScheduler { + public: + StepLR( + torch::optim::Optimizer& optimizer, + const unsigned step_size, + const double gamma = 0.1); + + private: + std::vector get_lrs() override; + + const unsigned step_size_; + const double gamma_; +}; +} // namespace torch::optim diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/serialize.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/serialize.h new file mode 100644 index 0000000000000000000000000000000000000000..50f66782f276300bf234378f4b542aaedd278860 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/serialize.h @@ -0,0 +1,315 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch::optim { +namespace detail { +// Utility function to save state +template +void serialize( + serialize::OutputArchive& archive, + const ska::flat_hash_map>& + state) { + for (const auto& item : state) { + serialize::OutputArchive param_state_archive(archive.compilation_unit()); + std::string tensorimpl_key = + std::to_string(reinterpret_cast(item.first)); + const DerivedOptimizerParamState& curr_state = + static_cast(*(item.second)); + curr_state.serialize(param_state_archive); + archive.write(tensorimpl_key, param_state_archive); + } +} + +// Utility function to load state +template +void serialize( + serialize::InputArchive& archive, + ska::flat_hash_map>& state) { + std::vector tensorimpl_keys = archive.keys(); + for (const std::string& tensorimpl_key : tensorimpl_keys) { + serialize::InputArchive param_state_archive; + archive.read(tensorimpl_key, param_state_archive); + DerivedOptimizerParamState param_state; + param_state.serialize(param_state_archive); + // NOLINTNEXTLINE(performance-no-int-to-ptr) + state[reinterpret_cast(std::stoull(tensorimpl_key))] = + std::make_unique(param_state); + } +} + +// Utility function to save param_groups +template +void serialize( + serialize::OutputArchive& archive, + const std::vector& param_groups) { + archive.write( + "param_groups/size", + torch::tensor(static_cast(param_groups.size()))); + for (const auto i : c10::irange(param_groups.size())) { + serialize::OutputArchive param_group_archive(archive.compilation_unit()); + std::vector params = param_groups[i].params(); + param_group_archive.write( + "params/size", torch::tensor(static_cast(params.size()))); + for (const auto index : c10::irange(params.size())) { + param_group_archive.write( + "params/" + std::to_string(index), + IValue(std::to_string( + reinterpret_cast(params[index].unsafeGetTensorImpl())))); + } + const DerivedOptimizerParamOptions& param_group_options = + static_cast( + param_groups[i].options()); + serialize::OutputArchive param_group_options_archive( + param_group_archive.compilation_unit()); + param_group_options.serialize(param_group_options_archive); + param_group_archive.write("options", param_group_options_archive); + archive.write("param_groups/" + std::to_string(i), param_group_archive); + } +} + +// Utility function to load param_groups +// We take as input vector of pair of string and unique_ptr to optimizer options +// so that we can retain the state for each param by using the old tensor impl +// keys (saved during serialization) and map the new tensor impl keys to the +// correct state for each param +template +void serialize( + serialize::InputArchive& archive, + std::vector< + std::pair, std::unique_ptr>>& + param_groups) { + torch::Tensor param_groups_size_tensor; + archive.read("param_groups/size", param_groups_size_tensor); + const int64_t param_groups_size = param_groups_size_tensor.item(); + for (const auto i : c10::irange(param_groups_size)) { + serialize::InputArchive param_group_archive; + archive.read("param_groups/" + std::to_string(i), param_group_archive); + torch::Tensor size_tensor; + param_group_archive.read("params/size", size_tensor); + const int64_t size = size_tensor.item(); + std::vector params; + for (const auto index : c10::irange(size)) { + IValue ivalue; + param_group_archive.read("params/" + std::to_string(index), ivalue); + std::string element = ivalue.toStringRef(); + params.emplace_back(element); + } + serialize::InputArchive param_group_options_archive; + param_group_archive.read("options", param_group_options_archive); + DerivedOptimizerParamOptions param_group_options(0); + param_group_options.serialize(param_group_options_archive); + param_groups.emplace_back(std::make_pair( + params, + std::make_unique(param_group_options))); + } +} +} // namespace detail + +// Note: These functions are all called `serialize()` so they can be called +// inside a template where the archive type is a template type and can thus be +// passed such that the appropriate overload is selected. + +/// Utility function to save a value of `int64_t` type. +void serialize( + serialize::OutputArchive& archive, + const std::string& key, + const int64_t& value); + +/// Utility function to load a value of `int64_t` type. +void serialize( + serialize::InputArchive& archive, + const std::string& key, + int64_t& value); + +/// Utility function to save a vector of step buffers. +void serialize( + serialize::OutputArchive& archive, + const std::string& key, + const std::vector& steps); + +/// Utility function to load a vector of step buffers. +void serialize( + serialize::InputArchive& archive, + const std::string& key, + std::vector& steps); + +// Utility function to save state and param_groups +template < + typename DerivedOptimizerParamState, + typename DerivedOptimizerParamOptions> +void serialize(serialize::OutputArchive& archive, const Optimizer& optimizer) { + archive.write("pytorch_version", IValue("1.5.0")); + serialize::OutputArchive state_archive(archive.compilation_unit()); + detail::serialize( + state_archive, optimizer.state()); + archive.write("state", state_archive); + + serialize::OutputArchive param_groups_archive(archive.compilation_unit()); + detail::serialize( + param_groups_archive, optimizer.param_groups()); + archive.write("param_groups", param_groups_archive); +} + +// Utility function to load state and param_groups and update state +template < + typename DerivedOptimizerParamState, + typename DerivedOptimizerParamOptions> +void serialize(serialize::InputArchive& archive, Optimizer& optimizer) { + IValue pytorch_version; + archive.read("pytorch_version", pytorch_version); + TORCH_INTERNAL_ASSERT(pytorch_version.toStringRef() == "1.5.0"); + serialize::InputArchive state_archive; + archive.read("state", state_archive); + ska::flat_hash_map> saved_state; + detail::serialize(state_archive, saved_state); + + serialize::InputArchive param_groups_archive; + archive.read("param_groups", param_groups_archive); + std::vector< + std::pair, std::unique_ptr>> + saved_param_groups; + detail::serialize( + param_groups_archive, saved_param_groups); + + // update state and optimizer options + TORCH_CHECK( + saved_param_groups.size() == optimizer.param_groups().size(), + "loaded state dict has a different number of parameter groups"); + for (const auto i : c10::irange(saved_param_groups.size())) { + std::vector param_group_old_keys = saved_param_groups[i].first; + std::vector params = optimizer.param_groups()[i].params(); + TORCH_CHECK( + param_group_old_keys.size() == params.size(), + "loaded state dict contains a parameter group that has a different size than the optimizer's parameter group"); + + for (const auto idx : c10::irange(params.size())) { + auto param_group_old_key = + // NOLINTNEXTLINE(performance-no-int-to-ptr) + reinterpret_cast(std::stoull(param_group_old_keys[idx])); + if (saved_state.find(param_group_old_key) != saved_state.end()) { + optimizer.state()[params[idx].unsafeGetTensorImpl()] = + std::move(saved_state[param_group_old_key]); + } + } + + auto& saved_options = reinterpret_cast( + *saved_param_groups[i].second); + auto& current_options = reinterpret_cast( + optimizer.param_groups()[i].options()); + current_options = saved_options; + } +} + +/// Utility function to save a vector of buffers. +template +void serialize( + serialize::OutputArchive& archive, + const std::string& key, + const BufferContainer& buffers) { + archive.write( + key + "/size", torch::tensor(static_cast(buffers.size()))); + for (const auto index : c10::irange(buffers.size())) { + archive.write( + key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true); + } +} + +/// Utility function to load a vector of buffers. +template +void serialize( + serialize::InputArchive& archive, + const std::string& key, + BufferContainer& buffers) { + buffers.clear(); + torch::Tensor size_tensor; + archive.read(key + "/size", size_tensor); + const size_t size = size_tensor.item(); + for (const auto index : c10::irange(size)) { + buffers.emplace_back(); + archive.read( + key + "/" + std::to_string(index), buffers.back(), /*is_buffer=*/true); + } +} + +template +c10::List deque_to_list(const std::deque& dq) { + c10::List list; + list.reserve(dq.size()); + for (const auto& e : dq) { + list.emplace_back(e); + } + return list; +} + +template +std::deque list_to_deque(const c10::List& list) { + std::deque dq; + for (const auto& e : list) { + dq.emplace_back(e); + } + return dq; +} + +#define _TORCH_OPTIM_SERIALIZE(name) \ + torch::optim::serialize(archive, #name, self.name) + +#define _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(OptimizerName) \ + torch::optim::serialize( \ + archive, self) + +#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG(name) \ + { \ + auto ivalue = torch::IValue(name()); \ + /* do not serialize if name is an undefined tensor*/ \ + if (!(ivalue.isTensor() && \ + ivalue.unsafeToTensorImpl() == \ + at::UndefinedTensorImpl::singleton())) { \ + archive.write(#name, ivalue); \ + } \ + } + +#define _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(name) \ + { \ + c10::IValue ivalue = torch::IValue(deque_to_list(name())); \ + archive.write(#name, ivalue); \ + } + +#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(T, name) \ + { \ + c10::IValue ivalue; \ + bool exists = archive.try_read(#name, ivalue); \ + if (exists) { \ + name(ivalue.to()); \ + } else { \ + constexpr bool is_tensor_type = std::is_base_of_v; \ + TORCH_INTERNAL_ASSERT(is_tensor_type); \ + } \ + } + +#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(T, name) \ + { \ + c10::IValue ivalue; \ + bool exists = archive.try_read(#name, ivalue); \ + if (exists) { \ + name(ivalue.toOptional()); \ + } \ + } + +#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(T, name) \ + { \ + c10::IValue ivalue; \ + archive.read(#name, ivalue); \ + auto list = ivalue.to>(); \ + name(list_to_deque(list)); \ + } + +} // namespace torch::optim diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/sgd.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/sgd.h new file mode 100644 index 0000000000000000000000000000000000000000..34896fb15653d2b6fae7e88f0fd8dfe6be4a2c4f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/optim/sgd.h @@ -0,0 +1,85 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::serialize { +class OutputArchive; +class InputArchive; +} // namespace torch::serialize + +namespace torch::optim { + +struct TORCH_API SGDOptions : public OptimizerCloneableOptions { + SGDOptions(double lr); + TORCH_ARG(double, lr); + TORCH_ARG(double, momentum) = 0; + TORCH_ARG(double, dampening) = 0; + TORCH_ARG(double, weight_decay) = 0; + TORCH_ARG(bool, nesterov) = false; + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const SGDOptions& lhs, + const SGDOptions& rhs); + double get_lr() const override; + void set_lr(const double lr) override; +}; + +struct TORCH_API SGDParamState + : public OptimizerCloneableParamState { + TORCH_ARG(torch::Tensor, momentum_buffer); + + public: + void serialize(torch::serialize::InputArchive& archive) override; + void serialize(torch::serialize::OutputArchive& archive) const override; + TORCH_API friend bool operator==( + const SGDParamState& lhs, + const SGDParamState& rhs); +}; + +class TORCH_API SGD : public Optimizer { + public: + explicit SGD( + const std::vector& param_groups, + SGDOptions defaults) + : Optimizer(param_groups, std::make_unique(defaults)) { + TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); + TORCH_CHECK( + defaults.momentum() >= 0, + "Invalid momentum value: ", + defaults.momentum()); + TORCH_CHECK( + defaults.weight_decay() >= 0, + "Invalid weight_decay value: ", + defaults.weight_decay()); + TORCH_CHECK( + !defaults.nesterov() || + (defaults.momentum() > 0 && defaults.dampening() == 0), + "Nesterov momentum requires a momentum and zero dampening"); + } + + explicit SGD(std::vector params, SGDOptions defaults) + : SGD({OptimizerParamGroup(std::move(params))}, std::move(defaults)) {} + + torch::Tensor step(LossClosure closure = nullptr) override; + + void save(serialize::OutputArchive& archive) const override; + void load(serialize::InputArchive& archive) override; + + private: + template + static void serialize(Self& self, Archive& archive) { + _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(SGD); + } +}; +} // namespace torch::optim diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/ordered_dict.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/ordered_dict.h new file mode 100644 index 0000000000000000000000000000000000000000..874c73b4ee06bea9f58c2ba7841f2426805714ab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/ordered_dict.h @@ -0,0 +1,516 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace torch { +/// An ordered dictionary implementation, akin to Python's `OrderedDict`. +template +class OrderedDict { + public: + /// A (key, value) pair. + class Item; + + // The lifetime of an iterator is bound to the lifetime of the `OrderedDict`. + // Further, any `insert()` operation may invalidate all iterators + // pointing into the vector. + using Iterator = typename std::vector::iterator; + using ConstIterator = typename std::vector::const_iterator; + + /// Constructs the `OrderedDict` with a short description of the kinds of keys + /// stored in the `OrderedDict`. This description is used in error messages + /// thrown by the `OrderedDict`. + explicit OrderedDict(std::string key_description = "Key"); + + /// Copy constructs this `OrderedDict` from `other`. + OrderedDict(const OrderedDict& other); + + /// Assigns items from `other` to this `OrderedDict`. + OrderedDict& operator=(const OrderedDict& other); + + // NB: Move works by default, because you can move-construct vectors of const + // values. I tried to make this noexcept (conditional on the move constructors + // of index_ and items_ being noexcept) but the obvious spelling didn't + // compile on Windows. + OrderedDict(OrderedDict&& other) noexcept = default; + OrderedDict& operator=(OrderedDict&& other) noexcept = default; + + ~OrderedDict() = default; + + /// Constructs a new `OrderedDict` and pre-populates it with the given + /// `Item`s. + /*implicit */ OrderedDict(std::initializer_list initializer_list); + + /// Returns the key description string the `OrderedDict` was constructed with. + const std::string& key_description() const noexcept; + + // Element Access + + /// Returns the very first item in the `OrderedDict` and throws an exception + /// if it is empty. + Item& front(); + + /// Returns the very first item in the `OrderedDict` and throws an exception + /// if it is empty. + const Item& front() const; + + /// Returns the very last item in the `OrderedDict` and throws an exception + /// if it is empty. + Item& back(); + + /// Returns the very last item in the `OrderedDict` and throws an exception + /// if it is empty. + const Item& back() const; + + /// Returns the item at the `index`-th position in the `OrderedDict`. Throws + /// an exception if the index is out of bounds. + Item& operator[](size_t index); + + /// Returns the item at the `index`-th position in the `OrderedDict`. Throws + /// an exception if the index is out of bounds. + const Item& operator[](size_t index) const; + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `OrderedDict`. Use `find()` for a + /// non-throwing way of accessing a value if it is present. + Value& operator[](const Key& key); + + /// Returns the value associated with the given `key`. Throws an exception if + /// no such key is stored in the `OrderedDict`. Use `find()` for a + /// non-throwing way of accessing a value if it is present. + const Value& operator[](const Key& key) const; + + // Lookup + + /// Returns a pointer to the value associated with the given key, or a + /// `nullptr` if no such key is stored in the `OrderedDict`. + Value* find(const Key& key) noexcept; + + /// Returns a pointer to the value associated with the given key, or a + /// `nullptr` if no such key is stored in the `OrderedDict`. + const Value* find(const Key& key) const noexcept; + + /// Returns true if the key is present in the `OrderedDict`. + bool contains(const Key& key) const noexcept; + + // Iterators + + /// Returns an iterator to the first item in the `OrderedDict`. Iteration is + /// ordered. + Iterator begin(); + + /// Returns an iterator to the first item in the `OrderedDict`. Iteration is + /// ordered. + ConstIterator begin() const; + + /// Returns an iterator one past the last item in the `OrderedDict`. + Iterator end(); + + /// Returns an iterator one past the last item in the `OrderedDict`. + ConstIterator end() const; + + // Capacity + + /// Returns the number of items currently stored in the `OrderedDict`. + size_t size() const noexcept; + + /// Returns true if the `OrderedDict` contains no elements. + bool is_empty() const noexcept; + + /// Resizes internal storage to fit at least `requested_capacity` items + /// without requiring reallocation. + void reserve(size_t requested_capacity); + + // Modifiers + + /// Inserts a new `(key, value)` pair into the `OrderedDict`. Throws an + /// exception if the key is already present. If insertion is successful, + /// immediately returns a reference to the inserted value. + template + Value& insert(K&& key, V&& value); + + /// Inserts a new `(key, value)` pair into the `OrderedDict`. Throws an + /// exception if the key is already present. If insertion is successful, + /// immediately returns a reference to the inserted value. + Value& insert(Key key, Value&& value); + + /// Inserts all items from `other` into this `OrderedDict`. If any key from + /// `other` is already present in this `OrderedDict`, an exception is thrown. + void update(OrderedDict&& other); + + /// Inserts all items from `other` into this `OrderedDict`. If any key from + /// `other` is already present in this `OrderedDict`, an exception is thrown. + void update(const OrderedDict& other); + + /// Removes the item that has `key` from this `OrderedDict` if exists and if + /// it doesn't an exception is thrown. + void erase(const Key& key); + + /// Removes all items from this `OrderedDict`. + void clear(); + + // Observers + + /// Returns the items stored in the `OrderedDict`. + const std::vector& items() const noexcept; + + /// Returns a newly allocated vector and copies all keys from this + /// `OrderedDict` into the vector. + ::std::vector keys() const; + + /// Returns a newly allocated vector and copies all values from this + /// `OrderedDict` into the vector. + ::std::vector values() const; + + /// Returns a newly allocated vector and copies all keys and values from this + /// `OrderedDict` into a vector of `std::pair`. + ::std::vector> pairs() const; + + /// Returns true if both dicts contain the same keys and values, in the same + /// order. + template + friend bool operator==( + const OrderedDict& a, + const OrderedDict& b); + + private: + /// A mapping from a key to an index into the `items_` vector. + ::std::unordered_map index_; + + /// The items stored in the `OrderedDict`. + ::std::vector items_; + + /// A description of the keys stored in the `OrderedDict`. + ::std::string key_description_{"Key"}; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict::Item ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +class OrderedDict::Item { + public: + /// Constructs a new item. + Item(Key key, Value value) : pair_(std::move(key), std::move(value)) {} + + /// Returns a reference to the value. + Value& operator*() { + return value(); + } + + /// Returns a reference to the value. + const Value& operator*() const { + return value(); + } + + /// Allows access to the value using the arrow operator. + Value* operator->() { + return &value(); + } + + /// Allows access to the value using the arrow operator. + const Value* operator->() const { + return &value(); + } + + /// Returns a reference to the key. + const Key& key() const noexcept { + return pair_.first; + } + + /// Returns a reference to the value. + Value& value() noexcept { + return pair_.second; + } + + /// Returns a reference to the value. + const Value& value() const noexcept { + return pair_.second; + } + + /// Returns a `(key, value)` pair. + const std::pair& pair() const noexcept { + return pair_; + } + + private: + /// This is stored as an std::pair because it will make Python binding a lot, + /// lot easier. + ::std::pair pair_; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +OrderedDict::OrderedDict(std::string key_description) + : key_description_(std::move(key_description)) {} + +template +OrderedDict::OrderedDict(const OrderedDict& other) + : index_(other.index_), key_description_(other.key_description_) { + // Copy we have to do ourselves, because items' keys are const, so we have to + // re-insert the items. + for (const auto& item : other.items_) { + items_.push_back(item); + } +} + +template +OrderedDict& OrderedDict::operator=( + const OrderedDict& other) { + index_ = other.index_; + items_.clear(); + for (auto& item : other.items_) { + items_.push_back(item); + } + key_description_ = other.key_description_; + return *this; +} + +template +OrderedDict::OrderedDict( + std::initializer_list initializer_list) + : OrderedDict("Key") { + items_.reserve(initializer_list.size()); + for (auto& item : initializer_list) { + // Copy the key here and move it into the index. + items_.emplace_back(item.key(), std::move(item.value())); + index_.emplace(std::move(item.key()), size() - 1); + } +} + +template +typename OrderedDict::Iterator OrderedDict::begin() { + return items_.begin(); +} + +template +typename OrderedDict::ConstIterator OrderedDict::begin() + const { + return items_.begin(); +} + +template +typename OrderedDict::Iterator OrderedDict::end() { + return items_.end(); +} + +template +typename OrderedDict::ConstIterator OrderedDict::end() + const { + return items_.end(); +} + +template +typename OrderedDict::Item& OrderedDict::front() { + TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict"); + return items_.front(); +} + +template +const typename OrderedDict::Item& OrderedDict::front() + const { + TORCH_CHECK(!items_.empty(), "Called front() on an empty OrderedDict"); + return items_.front(); +} + +template +typename OrderedDict::Item& OrderedDict::back() { + TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict"); + return items_.back(); +} + +template +const typename OrderedDict::Item& OrderedDict::back() + const { + TORCH_CHECK(!items_.empty(), "Called back() on an empty OrderedDict"); + return items_.back(); +} + +template +typename OrderedDict::Item& OrderedDict::operator[]( + size_t index) { + TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds"); + return items_[index]; +} + +template +const typename OrderedDict::Item& OrderedDict:: +operator[](size_t index) const { + TORCH_CHECK(index < items_.size(), "Index ", index, " is out of bounds"); + return items_[index]; +} + +template +Value& OrderedDict::operator[](const Key& key) { + if (auto* value = find(key)) { + return *value; + } + TORCH_CHECK(false, key_description_, " '", key, "' is not defined"); +} + +template +const Value& OrderedDict::operator[](const Key& key) const { + if (auto* value = find(key)) { + return *value; + } + TORCH_CHECK(false, key_description_, " '", key, "' is not defined"); +} + +template +template +Value& OrderedDict::insert(K&& key, V&& value) { + TORCH_CHECK( + index_.count(key) == 0, key_description_, " '", key, "' already defined"); + // Copy `key` here and move it into the index. + items_.emplace_back(key, std::forward(value)); + index_.emplace(std::forward(key), size() - 1); + return items_.back().value(); +} + +template +Value& OrderedDict::insert(Key key, Value&& value) { + return insert(std::move(key), std::move(value)); +} + +template +void OrderedDict::update(OrderedDict&& other) { + reserve(size() + other.size()); + for (auto&& item : std::move(other)) { + // We want to call `insert()` to prevent duplicate keys. + insert(std::move(item.key()), std::move(item.value())); + } +} + +template +void OrderedDict::update(const OrderedDict& other) { + reserve(size() + other.size()); + for (auto& item : other) { + // We want to call `insert()` to prevent duplicate keys. + insert(item.key(), item.value()); + } +} + +template +Value* OrderedDict::find(const Key& key) noexcept { + auto iterator = index_.find(key); + if (iterator == index_.end()) { + return nullptr; + } + return &items_[iterator->second].value(); +} + +template +const Value* OrderedDict::find(const Key& key) const noexcept { + auto iterator = index_.find(key); + if (iterator == index_.end()) { + return nullptr; + } + return &items_[iterator->second].value(); +} + +template +void OrderedDict::erase(const Key& key) { + auto it = index_.find(key); + TORCH_CHECK(it != index_.end(), "Key '", key, "' doesn't exist"); + + auto index = it->second; + index_.erase(it); + items_.erase(items_.begin() + index); + + for (auto& pair : index_) + if (pair.second > index) + --pair.second; +} + +template +bool OrderedDict::contains(const Key& key) const noexcept { + return find(key) != nullptr; +} + +template +void OrderedDict::clear() { + index_.clear(); + items_.clear(); +} + +template +size_t OrderedDict::size() const noexcept { + return items_.size(); +} + +template +bool OrderedDict::is_empty() const noexcept { + return items_.empty(); +} + +template +const std::string& OrderedDict::key_description() const noexcept { + return key_description_; +} + +template +const std::vector::Item>& OrderedDict< + Key, + Value>::items() const noexcept { + return items_; +} + +template +::std::vector OrderedDict::keys() const { + std::vector keys; + keys.reserve(size()); + for (const auto& item : items_) { + keys.push_back(item.key()); + } + return keys; +} + +template +::std::vector OrderedDict::values() const { + std::vector values; + values.reserve(size()); + for (const auto& item : items_) { + values.push_back(item.value()); + } + return values; +} + +template +::std::vector> OrderedDict::pairs() const { + std::vector> values; + values.reserve(size()); + for (const auto& item : items_) { + values.push_back(item.pair()); + } + return values; +} + +template +void OrderedDict::reserve(size_t requested_capacity) { + index_.reserve(requested_capacity); + items_.reserve(requested_capacity); +} + +template +bool operator==( + const torch::OrderedDict& a, + const torch::OrderedDict& b) { + using Item = typename torch::OrderedDict::Item; + if (a.index_ != b.index_) + return false; + if (a.items_.size() != b.items_.size()) + return false; + // NOTE: There's no point in comparing keys for items_, as we already know + // that index is equal. + return std::equal( + a.items_.begin(), + a.items_.end(), + b.items_.begin(), + [](const Item& a, const Item& b) { return a.value() == b.value(); }); +} + +} // namespace torch diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/python.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/python.h new file mode 100644 index 0000000000000000000000000000000000000000..1d65bc221fd50bf883f1ddd245c9a3e704309f8d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/python.h @@ -0,0 +1,259 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::python { +namespace detail { +inline Device py_object_to_device(py::object object) { + PyObject* obj = object.ptr(); + if (THPDevice_Check(obj)) { + return reinterpret_cast(obj)->device; + } + throw TypeError("Expected device"); +} + +inline Dtype py_object_to_dtype(py::object object) { + PyObject* obj = object.ptr(); + if (THPDtype_Check(obj)) { + return reinterpret_cast(obj)->scalar_type; + } + throw TypeError("Expected dtype"); +} + +template +using PyModuleClass = + py::class_>; + +/// Dynamically creates a subclass of `torch.nn.cpp.ModuleWrapper` that is also +/// a subclass of `torch.nn.Module`, and passes it the user-provided C++ module +/// to which it delegates all calls. +template +void bind_cpp_module_wrapper( + const py::module& module, + PyModuleClass cpp_class, + const char* name) { + // Grab the `torch.nn.cpp.ModuleWrapper` class, which we'll subclass + // with a dynamically created class below. + py::object cpp_module = + py::module::import("torch.nn.cpp").attr("ModuleWrapper"); + + // Grab the `type` class which we'll use as a metaclass to create a new class + // dynamically. + py::object type_metaclass = + py::reinterpret_borrow((PyObject*)&PyType_Type); + + // The `ModuleWrapper` constructor copies all functions to its own `__dict__` + // in its constructor, but we do need to give our dynamic class a constructor. + // Inside, we construct an instance of the original C++ module we're binding + // (the `torch::nn::Module` subclass), and then forward it to the + // `ModuleWrapper` constructor. + py::dict attributes; + + // `type()` always needs a `str`, but pybind11's `str()` method always creates + // a `unicode` object. + py::object name_str = py::str(name); + + // Dynamically create the subclass of `ModuleWrapper`, which is a subclass of + // `torch.nn.Module`, and will delegate all calls to the C++ module we're + // binding. + py::object wrapper_class = + type_metaclass(name_str, py::make_tuple(cpp_module), attributes); + + // The constructor of the dynamic class calls `ModuleWrapper.__init__()`, + // which replaces its methods with those of the C++ module. + wrapper_class.attr("__init__") = py::cpp_function( + [cpp_module, cpp_class]( + const py::object& self, + const py::args& args, + const py::kwargs& kwargs) { + cpp_module.attr("__init__")(self, cpp_class(*args, **kwargs)); + }, + py::is_method(wrapper_class)); + + // Calling `my_module.my_class` now means that `my_class` is a subclass of + // `ModuleWrapper`, and whose methods call into the C++ module we're binding. + module.attr(name) = wrapper_class; +} +} // namespace detail + +/// Adds method bindings for a pybind11 `class_` that binds an `nn::Module` +/// subclass. +/// +/// Say you have a pybind11 class object created with `py::class_(m, +/// "Net")`. This function will add all the necessary `.def()` calls to bind the +/// `nn::Module` base class' methods, such as `train()`, `eval()` etc. into +/// Python. +/// +/// Users should prefer to use `bind_module` if possible. +template +py::class_ add_module_bindings( + py::class_ module) { + // clang-format off + return module + .def("train", + [](ModuleType& module, bool mode) { module.train(mode); }, + py::arg("mode") = true) + .def("eval", [](ModuleType& module) { module.eval(); }) + .def("clone", [](ModuleType& module) { return module.clone(); }) + .def_property_readonly( + "training", [](ModuleType& module) { return module.is_training(); }) + .def("zero_grad", [](ModuleType& module) { module.zero_grad(); }) + .def_property_readonly( "_parameters", [](ModuleType& module) { + return module.named_parameters(/*recurse=*/false); + }) + .def("parameters", [](ModuleType& module, bool recurse) { + return module.parameters(recurse); + }, + py::arg("recurse") = true) + .def("named_parameters", [](ModuleType& module, bool recurse) { + return module.named_parameters(recurse); + }, + py::arg("recurse") = true) + .def_property_readonly("_buffers", [](ModuleType& module) { + return module.named_buffers(/*recurse=*/false); + }) + .def("buffers", [](ModuleType& module, bool recurse) { + return module.buffers(recurse); }, + py::arg("recurse") = true) + .def("named_buffers", [](ModuleType& module, bool recurse) { + return module.named_buffers(recurse); + }, + py::arg("recurse") = true) + .def_property_readonly( + "_modules", [](ModuleType& module) { return module.named_children(); }) + .def("modules", [](ModuleType& module) { return module.modules(); }) + .def("named_modules", + [](ModuleType& module, const py::object& /* unused */, std::string prefix, bool remove_duplicate /* unused */) { + return module.named_modules(std::move(prefix)); + }, + py::arg("memo") = py::none(), + py::arg("prefix") = std::string(), + py::arg("remove_duplicate") = true) + .def("children", [](ModuleType& module) { return module.children(); }) + .def("named_children", + [](ModuleType& module) { return module.named_children(); }) + .def("to", [](ModuleType& module, py::object object, bool non_blocking) { + if (THPDevice_Check(object.ptr())) { + module.to( + reinterpret_cast(object.ptr())->device, + non_blocking); + } else { + module.to(detail::py_object_to_dtype(object), non_blocking); + } + }, + py::arg("dtype_or_device"), + py::arg("non_blocking") = false) + .def("to", + [](ModuleType& module, + const py::object& device, + const py::object& dtype, + bool non_blocking) { + if (device.is_none()) { + module.to(detail::py_object_to_dtype(dtype), non_blocking); + } else if (dtype.is_none()) { + module.to(detail::py_object_to_device(device), non_blocking); + } else { + module.to( + detail::py_object_to_device(device), + detail::py_object_to_dtype(dtype), + non_blocking); + } + }, + py::arg("device"), + py::arg("dtype"), + py::arg("non_blocking") = false) + .def("cuda", [](ModuleType& module) { module.to(kCUDA); }) + .def("cpu", [](ModuleType& module) { module.to(kCPU); }) + .def("float", [](ModuleType& module) { module.to(kFloat32); }) + .def("double", [](ModuleType& module) { module.to(kFloat64); }) + .def("half", [](ModuleType& module) { module.to(kFloat16); }) + .def("__str__", [](ModuleType& module) { return module.name(); }) + .def("__repr__", [](ModuleType& module) { return module.name(); }); + // clang-format on +} + +/// Creates a pybind11 class object for an `nn::Module` subclass type and adds +/// default bindings. +/// +/// After adding the default bindings, the class object is returned, such that +/// you can add more bindings. +/// +/// Example usage: +/// \rst +/// .. code-block:: cpp +/// +/// struct Net : torch::nn::Module { +/// Net(int in, int out) { } +/// torch::Tensor forward(torch::Tensor x) { return x; } +/// }; +/// +/// PYBIND11_MODULE(my_module, m) { +/// torch::python::bind_module(m, "Net") +/// .def(py::init()) +/// .def("forward", &Net::forward); +/// } +/// \endrst +template +std::enable_if_t< + !torch::detail::has_forward::value || force_enable, + detail::PyModuleClass> +bind_module(py::module module, const char* name) { + py::module cpp = module.def_submodule("cpp"); + auto cpp_class = + add_module_bindings(detail::PyModuleClass(cpp, name)); + detail::bind_cpp_module_wrapper(module, cpp_class, name); + return cpp_class; +} + +/// Creates a pybind11 class object for an `nn::Module` subclass type and adds +/// default bindings. +/// +/// After adding the default bindings, the class object is returned, such that +/// you can add more bindings. +/// +/// If the class has a `forward()` method, it is automatically exposed as +/// `forward()` and `__call__` in Python. +/// +/// Example usage: +/// \rst +/// .. code-block:: cpp +/// +/// struct Net : torch::nn::Module { +/// Net(int in, int out) { } +/// torch::Tensor forward(torch::Tensor x) { return x; } +/// }; +/// +/// PYBIND11_MODULE(my_module, m) { +/// torch::python::bind_module(m, "Net") +/// .def(py::init()) +/// .def("forward", &Net::forward); +/// } +/// \endrst +template < + typename ModuleType, + typename = std::enable_if_t::value>> +detail::PyModuleClass bind_module( + py::module module, + const char* name) { + return bind_module(module, name) + .def("forward", &ModuleType::forward) + .def("__call__", &ModuleType::forward); +} +} // namespace torch::python diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/python/init.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/python/init.h new file mode 100644 index 0000000000000000000000000000000000000000..03edca27f47058cf4843b1915ec06d822983327f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/python/init.h @@ -0,0 +1,8 @@ +#pragma once + +#include + +namespace torch::python { +/// Initializes Python bindings for the C++ frontend. +void init_bindings(PyObject* module); +} // namespace torch::python diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize.h new file mode 100644 index 0000000000000000000000000000000000000000..f320542499cce8498649b90ad8a901c86937d739 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize.h @@ -0,0 +1,144 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace torch { + +/// Serializes the given `value`. +/// There must be an overload of `operator<<` between `serialize::OutputArchive` +/// and `Value` for this method to be well-formed. Currently, such an overload +/// is provided for (subclasses of): +/// +/// - `torch::nn::Module`, +/// - `torch::optim::Optimizer` +/// - `torch::Tensor` +/// +/// To perform the serialization, a `serialize::OutputArchive` is constructed, +/// and all arguments after the `value` are forwarded to its `save_to` method. +/// For example, you can pass a filename, or an `ostream`. +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::Linear model(3, 4); +/// torch::save(model, "model.pt"); +/// +/// torch::optim::SGD sgd(model->parameters(), 0.9); // 0.9 is learning rate +/// std::ostringstream stream; +/// // Note that the same stream cannot be used in multiple torch::save(...) +/// // invocations, otherwise the header will be corrupted. +/// torch::save(sgd, stream); +/// +/// auto tensor = torch::ones({3, 4}); +/// torch::save(tensor, "my_tensor.pt"); +/// \endrst +template +void save(const Value& value, SaveToArgs&&... args) { + serialize::OutputArchive archive(std::make_shared()); + archive << value; + archive.save_to(std::forward(args)...); +} + +/// Serializes the given `tensor_vec` of type `std::vector`. +/// +/// To perform the serialization, a `serialize::OutputArchive` is constructed, +/// and all arguments after the `tensor_vec` are forwarded to its `save_to` +/// method. For example, you can pass a filename, or an `ostream`. +/// +/// \rst +/// .. code-block:: cpp +/// +/// std::vector tensor_vec = { torch::randn({1, 2}), +/// torch::randn({3, 4}) }; torch::save(tensor_vec, "my_tensor_vec.pt"); +/// +/// std::vector tensor_vec = { torch::randn({5, 6}), +/// torch::randn({7, 8}) }; std::ostringstream stream; +/// // Note that the same stream cannot be used in multiple torch::save(...) +/// // invocations, otherwise the header will be corrupted. +/// torch::save(tensor_vec, stream); +/// \endrst +template +void save(const std::vector& tensor_vec, SaveToArgs&&... args) { + serialize::OutputArchive archive(std::make_shared()); + for (const auto i : c10::irange(tensor_vec.size())) { + auto& value = tensor_vec[i]; + archive.write(std::to_string(i), value); + } + archive.save_to(std::forward(args)...); +} + +TORCH_API std::vector pickle_save(const torch::IValue& ivalue); +TORCH_API torch::IValue pickle_load(const std::vector& data); + +/// Deserializes the given `value`. +/// There must be an overload of `operator>>` between `serialize::InputArchive` +/// and `Value` for this method to be well-formed. Currently, such an overload +/// is provided for (subclasses of): +/// +/// - `torch::nn::Module`, +/// - `torch::optim::Optimizer` +/// - `torch::Tensor` +/// +/// To perform the serialization, a `serialize::InputArchive` is constructed, +/// and all arguments after the `value` are forwarded to its `load_from` method. +/// For example, you can pass a filename, or an `istream`. +/// +/// \rst +/// .. code-block:: cpp +/// +/// torch::nn::Linear model(3, 4); +/// torch::load(model, "model.pt"); +/// +/// torch::optim::SGD sgd(model->parameters(), 0.9); // 0.9 is learning rate +/// std::istringstream stream("..."); +/// torch::load(sgd, stream); +/// +/// auto tensor = torch::ones({3, 4}); +/// torch::load(tensor, "my_tensor.pt"); +/// \endrst +template +void load(Value& value, LoadFromArgs&&... args) { + serialize::InputArchive archive; + archive.load_from(std::forward(args)...); + archive >> value; +} + +/// Deserializes the given `tensor_vec` of type `std::vector`. +/// +/// To perform the serialization, a `serialize::InputArchive` is constructed, +/// and all arguments after the `value` are forwarded to its `load_from` method. +/// For example, you can pass a filename, or an `istream`. +/// +/// \rst +/// .. code-block:: cpp +/// +/// std::vector tensor_vec; +/// torch::load(tensor_vec, "my_tensor_vec.pt"); +/// +/// std::vector tensor_vec; +/// std::istringstream stream("..."); +/// torch::load(tensor_vec, stream); +/// \endrst +template +void load(std::vector& tensor_vec, LoadFromArgs&&... args) { + serialize::InputArchive archive; + archive.load_from(std::forward(args)...); + + // NOTE: The number of elements in the serialized `std::vector` + // is not known ahead of time, so we need a while-loop to increment the index, + // and use `archive.try_read(...)` to check whether we have reached the end of + // the serialized `std::vector`. + size_t index = 0; + torch::Tensor value; + while (archive.try_read(std::to_string(index), value)) { + tensor_vec.push_back(std::move(value)); + value = torch::Tensor(); + index++; + } +} +} // namespace torch diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize/archive.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize/archive.h new file mode 100644 index 0000000000000000000000000000000000000000..d4ebe8e9d54cc127dd2df4ad1ccbcd226b037326 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize/archive.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize/input-archive.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize/input-archive.h new file mode 100644 index 0000000000000000000000000000000000000000..6495d532c32ce13e44852569774f374082c002b6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize/input-archive.h @@ -0,0 +1,115 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace at { +class Tensor; +} // namespace at + +namespace torch { +using at::Tensor; +namespace jit { +struct Module; +} // namespace jit +} // namespace torch + +namespace torch::serialize { + +/// A recursive representation of tensors that can be deserialized from a file +/// or stream. In most cases, users should not have to interact with this class, +/// and should instead use `torch::load`. +class TORCH_API InputArchive final { + public: + /// Default-constructs the `InputArchive`. + InputArchive(); + + // Move is allowed. + InputArchive(InputArchive&&) = default; + InputArchive& operator=(InputArchive&&) = default; + + // Copy is disallowed. + InputArchive(InputArchive&) = delete; + InputArchive& operator=(InputArchive&) = delete; + + ~InputArchive() = default; + + /// Reads an `IValue` associated with a given `key`. + void read(const std::string& key, c10::IValue& ivalue); + + /// Reads an `IValue` associated with a given `key`. If there is no `IValue` + /// associated with the `key`, this returns false, otherwise it returns true. + bool try_read(const std::string& key, c10::IValue& ivalue); + + /// Reads a `tensor` associated with a given `key`. If there is no `tensor` + /// associated with the `key`, this returns false, otherwise it returns true. + /// If the tensor is expected to be a buffer (not differentiable), `is_buffer` + /// must be `true`. + bool try_read(const std::string& key, Tensor& tensor, bool is_buffer = false); + + /// Reads a `tensor` associated with a given `key`. + /// If the tensor is expected to be a buffer (not differentiable), `is_buffer` + /// must be `true`. + void read(const std::string& key, Tensor& tensor, bool is_buffer = false); + + /// Reads a `InputArchive` associated with a given `key`. If there is no + /// `InputArchive` associated with the `key`, this returns false, otherwise + /// it returns true. + bool try_read(const std::string& key, InputArchive& archive); + + /// Reads an `InputArchive` associated with a given `key`. + /// The archive can thereafter be used for further deserialization of the + /// nested data. + void read(const std::string& key, InputArchive& archive); + + /// Loads the `InputArchive` from a serialized representation stored in the + /// file at `filename`. Storage are remapped using device option. If device + /// is not specified, the module is loaded to the original device. + void load_from( + const std::string& filename, + std::optional device = std::nullopt); + + /// Loads the `InputArchive` from a serialized representation stored in the + /// given `stream`. Storage are remapped using device option. If device + /// is not specified, the module is loaded to the original device. + void load_from( + std::istream& stream, + std::optional device = std::nullopt); + + // Loads given the specified flat array. + void load_from( + const char* data, + size_t size, + std::optional device = std::nullopt); + + // Loads given the specified read and size functions. + void load_from( + const std::function& + read_func, + const std::function& size_func, + std::optional device = std::nullopt); + + // Returns the vector of keys in the input archive. + std::vector keys(); + + /// Forwards all arguments to `read()`. + /// Useful for generic code that can be reused for both `InputArchive` and + /// `OutputArchive` (where `operator()` forwards to `write()`). + template + void operator()(Ts&&... ts) { + read(std::forward(ts)...); + } + + private: + jit::Module module_; + std::string hierarchy_prefix_; +}; +} // namespace torch::serialize diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize/output-archive.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize/output-archive.h new file mode 100644 index 0000000000000000000000000000000000000000..f47aca4df95a51aa1ae4de98d487fbb947146b7b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize/output-archive.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include + +namespace at { +class Tensor; +} // namespace at + +namespace torch { +using at::Tensor; +namespace jit { +struct Module; +} // namespace jit +} // namespace torch + +namespace torch::serialize { +class TORCH_API OutputArchive final { + public: + explicit OutputArchive(std::shared_ptr cu); + explicit OutputArchive() + : cu_(std::make_shared()), + module_("__torch__.Module", cu_) {} + + // Move is allowed. + OutputArchive(OutputArchive&&) = default; + OutputArchive& operator=(OutputArchive&&) = default; + + // Copy is disallowed. + OutputArchive(OutputArchive&) = delete; + OutputArchive& operator=(OutputArchive&) = delete; + + std::shared_ptr compilation_unit() const { + return cu_; + } + + /// Writes an `IValue` to the `OutputArchive`. + void write(const std::string& key, const c10::IValue& ivalue); + + /// Writes a `(key, tensor)` pair to the `OutputArchive`, and marks it as + /// being or not being a buffer (non-differentiable tensor). + void write( + const std::string& key, + const Tensor& tensor, + bool is_buffer = false); + + /// Writes a nested `OutputArchive` under the given `key` to this + /// `OutputArchive`. + void write(const std::string& key, OutputArchive& nested_archive); + + /// Saves the `OutputArchive` into a serialized representation in a file at + /// `filename`. + void save_to(const std::string& filename); + + /// Saves the `OutputArchive` into a serialized representation into the given + /// `stream`. + void save_to(std::ostream& stream); + + /// Saves the `OutputArchive` into a serialized representation using the + /// given writer function. + void save_to(const std::function& func); + + /// Forwards all arguments to `write()`. + /// Useful for generic code that can be reused for both `OutputArchive` and + /// `InputArchive` (where `operator()` forwards to `read()`). + template + void operator()(Ts&&... ts) { + write(std::forward(ts)...); + } + + private: + std::shared_ptr cu_; + jit::Module module_; +}; +} // namespace torch::serialize diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize/tensor.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..9f77ed170db32a497c23feed05aae8c266ab282e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/serialize/tensor.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace torch { +inline serialize::OutputArchive& operator<<( + serialize::OutputArchive& archive, + const Tensor& tensor) { + archive.write("0", tensor); + return archive; +} + +inline serialize::InputArchive& operator>>( + serialize::InputArchive& archive, + Tensor& tensor) { + archive.read("0", tensor); + return archive; +} +} // namespace torch diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/sparse.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/sparse.h new file mode 100644 index 0000000000000000000000000000000000000000..753a07de8a6f07631d536930845ae749a6309738 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/sparse.h @@ -0,0 +1,3 @@ +#pragma once + +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/special.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/special.h new file mode 100644 index 0000000000000000000000000000000000000000..7ab96c123f4a2a21fce390558a66ff6814d6a9ba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/special.h @@ -0,0 +1,1403 @@ +#pragma once + +#include +#include + +namespace torch::special { + +/// Computes the natural logarithm of the absolute value of the gamma function +/// See https://pytorch.org/docs/main/special.html#torch.special.gammaln. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::gammaln(t); +/// ``` +inline Tensor gammaln(const Tensor& self) { + return torch::special_gammaln(self); +} + +inline Tensor& gammaln_out(Tensor& result, const Tensor& self) { + return torch::special_gammaln_out(result, self); +} + +/// Computes the regularized lower incomplete gamma function +/// See https://pytorch.org/docs/main/special.html#torch.special.gammainc. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// auto s = torch::randn(128, dtype=kDouble); +/// torch::special::gammainc(s, t); +/// ``` +inline Tensor gammainc(const Tensor& self, const Tensor& other) { + return torch::special_gammainc(self, other); +} + +inline Tensor& gammainc_out( + Tensor& result, + const Tensor& self, + const Tensor& other) { + return torch::special_gammainc_out(result, self, other); +} + +/// Computes the regularized upper incomplete gamma function +/// See https://pytorch.org/docs/main/special.html#torch.special.gammainc. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// auto s = torch::randn(128, dtype=kDouble); +/// torch::special::gammaincc(s, t); +/// ``` +inline Tensor gammaincc(const Tensor& self, const Tensor& other) { + return torch::special_gammaincc(self, other); +} + +inline Tensor& gammaincc_out( + Tensor& result, + const Tensor& self, + const Tensor& other) { + return torch::special_gammaincc_out(result, self, other); +} + +/// Computes the multivariate log-gamma function with dimension `p`, elementwise +/// See https://pytorch.org/docs/main/special.html#torch.special.multigammaln. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::multigammaln(t, 1); +/// ``` +inline Tensor multigammaln(const Tensor& self, int64_t p) { + return torch::special_multigammaln(self, p); +} + +inline Tensor& multigammaln_out(Tensor& result, const Tensor& self, int64_t p) { + return torch::special_multigammaln_out(result, self, p); +} + +/// Computes the nth derivative of the digamma function on the input. +/// See https:://pytorch.org/docs/main/special.html#torch.special.polygamma. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::polygamma(2, t); +/// ``` +inline Tensor polygamma(int64_t n, const Tensor& self) { + return torch::special_polygamma(n, self); +} + +inline Tensor& polygamma_out(Tensor& result, int64_t n, const Tensor& self) { + return torch::special_polygamma_out(result, n, self); +} + +/// Computes the logarithmic derivative of the gamma function on input +/// See https://pytorch.org/docs/main/special.html#torch.special.psi +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::psi(t); +/// ``` +inline Tensor psi(const Tensor& self) { + return torch::special_psi(self); +} + +inline Tensor& psi_out(Tensor& result, const Tensor& self) { + return torch::special_psi_out(result, self); +} + +/// Computes the logarithmic derivative of the gamma function on input +/// See https://pytorch.org/docs/main/special.html#torch.special.digamma +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::digamma(t); +/// ``` +inline Tensor digamma(const Tensor& self) { + return torch::special_digamma(self); +} + +inline Tensor& digamma_out(Tensor& result, const Tensor& self) { + return torch::special_digamma_out(result, self); +} + +/// Computes entropy of input, elementwise +/// See https://pytorch.org/docs/main/special.html#torch.special.entr. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::entr(t); +/// ``` +inline Tensor entr(const Tensor& self) { + return torch::special_entr(self); +} + +inline Tensor& entr_out(Tensor& result, const Tensor& self) { + return torch::special_entr_out(result, self); +} + +/// Computes the error function +/// See https://pytorch.org/docs/main/special.html#torch.special.erf. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::erf(t); +/// ``` +inline Tensor erf(const Tensor& self) { + return torch::special_erf(self); +} + +inline Tensor& erf_out(Tensor& result, const Tensor& self) { + return torch::special_erf_out(result, self); +} + +/// Computes the complementary error function +/// See https://pytorch.org/docs/main/special.html#torch.special.erfc. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::erfc(t); +/// ``` +inline Tensor erfc(const Tensor& self) { + return torch::special_erfc(self); +} + +inline Tensor& erfc_out(Tensor& result, const Tensor& self) { + return torch::special_erfc_out(result, self); +} + +/// Computes the scaled complementary error function +/// See https://pytorch.org/docs/main/special.html#torch.special.erfcx. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::erfcx(t); +/// ``` +inline Tensor erfcx(const Tensor& self) { + return torch::special_erfcx(self); +} + +inline Tensor& erfcx_out(Tensor& result, const Tensor& self) { + return torch::special_erfcx_out(result, self); +} + +/// Computes the inverse error function +/// See https://pytorch.org/docs/main/special.html#torch.special.erfinv. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::erfinv(t); +/// ``` +inline Tensor erfinv(const Tensor& self) { + return torch::special_erfinv(self); +} + +inline Tensor& erfinv_out(Tensor& result, const Tensor& self) { + return torch::special_erfinv_out(result, self); +} + +/// Computes the log of summed exponentials of each row of input in the given +/// dimension dim See +/// https://pytorch.org/docs/main/special.html#torch.special.logsumexp. +/// +/// Example: +/// ``` +/// auto t = torch::randn(3, 3); +/// torch::special::logsumexp(t, 1); +/// ``` +inline Tensor logsumexp(const Tensor& self, IntArrayRef dims, bool keepdim) { + return torch::special_logsumexp(self, dims, keepdim); +} + +inline Tensor& logsumexp_out( + Tensor& result, + const Tensor& self, + IntArrayRef dims, + bool keepdim) { + return torch::special_logsumexp_out(result, self, dims, keepdim); +} + +/// Computes the argument, x, for which the area under the Gaussian probability +/// density function (integrated from minus infinity to x) is equal to input, +/// elementwise. See +/// https://pytorch.org/docs/main/special.html#torch.special.ndtri +/// +/// Example: +/// ``` +/// auto t = torch::rand(128, dtype=kDouble); +/// torch::special::ndtri(t); +/// ``` +inline Tensor ndtri(const Tensor& self) { + return torch::special_ndtri(self); +} + +inline Tensor& ndtri_out(Tensor& result, const Tensor& self) { + return torch::special_ndtri_out(result, self); +} + +/// Computes the log of area under the standard Gaussian probability density +/// function, integrated from minus infinity to :attr:`input`, elementwise See +/// https://pytorch.org/docs/main/special.html#torch.special.log_ndtr +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::log_ndtr(t); +/// ``` +inline Tensor log_ndtr(const Tensor& self) { + return torch::special_log_ndtr(self); +} + +inline Tensor& log_ndtr_out(Tensor& result, const Tensor& self) { + return torch::special_log_ndtr_out(result, self); +} + +/// Computes the logit of input, elementwise. +/// See https://pytorch.org/docs/main/special.html#torch.special.logit. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::logit(t); +/// ``` +inline Tensor logit(const Tensor& self) { + return torch::special_logit(self); +} + +inline Tensor& logit_out(Tensor& result, const Tensor& self) { + return torch::special_logit_out(result, self); +} + +/// Computes the expit (also known as the logistic sigmoid function) of input, +/// elementwise See +/// https://pytorch.org/docs/main/special.html#torch.special.expit. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::expit(t); +/// ``` +inline Tensor expit(const Tensor& self) { + return torch::special_expit(self); +} + +inline Tensor& expit_out(Tensor& result, const Tensor& self) { + return torch::special_expit_out(result, self); +} + +/// Computes the base two exponential function of :attr:`input`, elementwise +/// See https://pytorch.org/docs/main/special.html#torch.special.exp2. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::exp2(t); +/// ``` +inline Tensor exp2(const Tensor& self) { + return torch::special_exp2(self); +} + +inline Tensor& exp2_out(Tensor& result, const Tensor& self) { + return torch::special_exp2_out(result, self); +} + +/// Computes the exponential of the elements minus 1, elementwise +/// See https://pytorch.org/docs/main/special.html#torch.special.expm1. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::expm1(t); +/// ``` +inline Tensor expm1(const Tensor& self) { + return torch::special_expm1(self); +} + +inline Tensor& expm1_out(Tensor& result, const Tensor& self) { + return torch::special_expm1_out(result, self); +} + +/// Computes x * log(y) for inputs, elementwise +/// See https://pytorch.org/docs/main/special.html#torch.special.xlogy. +/// +/// Example: +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto y = torch::randn(128, dtype=kDouble); +/// torch::special::xlogy(x, y); +/// ``` +inline Tensor xlogy(const Tensor& self, const Tensor& other) { + return torch::special_xlogy(self, other); +} + +inline Tensor xlogy(const Scalar& self, const Tensor& other) { + return torch::special_xlogy(self, other); +} + +inline Tensor xlogy(const Tensor& self, const Scalar& other) { + return torch::special_xlogy(self, other); +} + +inline Tensor& xlogy_out( + Tensor& result, + const Tensor& self, + const Tensor& other) { + return torch::special_xlogy_out(result, self, other); +} + +inline Tensor& xlogy_out( + Tensor& result, + const Scalar& self, + const Tensor& other) { + return torch::special_xlogy_out(result, self, other); +} + +inline Tensor& xlogy_out( + Tensor& result, + const Tensor& self, + const Scalar& other) { + return torch::special_xlogy_out(result, self, other); +} + +/// Computes x * log1p(y) for inputs, elementwise +/// See https://pytorch.org/docs/main/special.html#torch.special.xlog1py. +/// +/// Example: +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto y = torch::randn(128, dtype=kDouble); +/// torch::special::xlog1py(x, y); +/// ``` +inline Tensor xlog1py(const Tensor& self, const Tensor& other) { + return torch::special_xlog1py(self, other); +} + +inline Tensor xlog1py(const Scalar& self, const Tensor& other) { + return torch::special_xlog1py(self, other); +} + +inline Tensor xlog1py(const Tensor& self, const Scalar& other) { + return torch::special_xlog1py(self, other); +} + +inline Tensor& xlog1py_out( + Tensor& result, + const Tensor& self, + const Tensor& other) { + return torch::special_xlog1py_out(result, self, other); +} + +inline Tensor& xlog1py_out( + Tensor& result, + const Scalar& self, + const Tensor& other) { + return torch::special_xlog1py_out(result, self, other); +} + +inline Tensor& xlog1py_out( + Tensor& result, + const Tensor& self, + const Scalar& other) { + return torch::special_xlog1py_out(result, self, other); +} + +/// Computes Hurwitz Zeta function for inputs, elementwise +/// See https://pytorch.org/docs/main/special.html#torch.special.zeta. +/// +/// Example: +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto y = torch::randn(128, dtype=kDouble); +/// torch::special::zeta(x, y); +/// ``` +inline Tensor zeta(const Tensor& self, const Tensor& other) { + return torch::special_zeta(self, other); +} + +inline Tensor zeta(const Scalar& self, const Tensor& other) { + return torch::special_zeta(self, other); +} + +inline Tensor zeta(const Tensor& self, const Scalar& other) { + return torch::special_zeta(self, other); +} + +inline Tensor& zeta_out( + Tensor& result, + const Tensor& self, + const Tensor& other) { + return torch::special_zeta_out(result, self, other); +} + +inline Tensor& zeta_out( + Tensor& result, + const Scalar& self, + const Tensor& other) { + return torch::special_zeta_out(result, self, other); +} + +inline Tensor& zeta_out( + Tensor& result, + const Tensor& self, + const Scalar& other) { + return torch::special_zeta_out(result, self, other); +} + +/// Computes the zeroth order modified Bessel function of the first kind of +/// input, elementwise See +/// https://pytorch.org/docs/main/special.html#torch.special.i0 +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::i0(t); +/// ``` +inline Tensor i0(const Tensor& self) { + return torch::special_i0(self); +} + +inline Tensor& i0_out(Tensor& result, const Tensor& self) { + return torch::special_i0_out(result, self); +} + +/// Computes the area under the standard Gaussian probability density function, +/// integrated from minus infinity to :attr:`input`, elementwise +/// See https://pytorch.org/docs/main/special.html#torch.special.ndtr +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::ndtr(t); +/// ``` +inline Tensor ndtr(const Tensor& self) { + return torch::special_ndtr(self); +} + +inline Tensor& ndtr_out(Tensor& result, const Tensor& self) { + return torch::special_ndtr_out(result, self); +} + +/// Computes the exponentially scaled zeroth order modified Bessel function of +/// the first kind See +/// https://pytorch.org/docs/main/special.html#torch.special.i0e. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::i0e(t); +/// ``` +inline Tensor i0e(const Tensor& self) { + return torch::special_i0e(self); +} + +inline Tensor& i0e_out(Tensor& result, const Tensor& self) { + return torch::special_i0e_out(result, self); +} + +/// Computes the first order modified Bessel function of the first kind +/// See https://pytorch.org/docs/main/special.html#torch.special.i1. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::i1(t); +/// ``` +inline Tensor i1(const Tensor& self) { + return torch::special_i1(self); +} + +inline Tensor& i1_out(Tensor& result, const Tensor& self) { + return torch::special_i1_out(result, self); +} + +/// Computes the exponentially scaled first order modified Bessel function of +/// the first kind See +/// https://pytorch.org/docs/main/special.html#torch.special.i1e. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::i1e(t); +/// ``` +inline Tensor i1e(const Tensor& self) { + return torch::special_i1e(self); +} + +inline Tensor& i1e_out(Tensor& result, const Tensor& self) { + return torch::special_i1e_out(result, self); +} + +/// Computes the sinc of input, elementwise +/// See https://pytorch.org/docs/main/special.html#torch.special.sinc. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::sinc(t); +/// ``` +inline Tensor sinc(const Tensor& self) { + return torch::special_sinc(self); +} + +inline Tensor& sinc_out(Tensor& result, const Tensor& self) { + return torch::special_sinc_out(result, self); +} + +/// Rounds the elements of the input +/// See https://pytorch.org/docs/main/special.html#torch.special.round. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::round(t); +/// ``` +inline Tensor round(const Tensor& self) { + return torch::special_round(self); +} + +inline Tensor& round_out(Tensor& result, const Tensor& self) { + return torch::special_round_out(result, self); +} + +/// Computes log(1 + x) of the input, elementwise +/// See https://pytorch.org/docs/main/special.html#torch.special.log1p. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::log1p(t); +/// ``` +inline Tensor log1p(const Tensor& self) { + return torch::special_log1p(self); +} + +inline Tensor& log1p_out(Tensor& result, const Tensor& self) { + return torch::special_log1p_out(result, self); +} + +/// Computes log followed by softmax(x) of the input +/// See https://pytorch.org/docs/main/special.html#torch.special.log_softmax. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, 128, dtype=kDouble); +/// torch::special::log_softmax(t, 0); +/// ``` +inline Tensor log_softmax( + const Tensor& self, + int64_t dim, + std::optional dtype) { + return torch::special_log_softmax(self, dim, dtype); +} + +/// Computes softmax of the input along a given dimension +/// See https://pytorch.org/docs/main/special.html#torch.special.softmax. +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, 128, dtype=kDouble); +/// torch::special::softmax(t, 0); +/// ``` +inline Tensor softmax( + const Tensor& self, + int64_t dim, + std::optional dtype) { + return torch::special_softmax(self, dim, dtype); +} + +/// Airy function Ai. +/// +/// See https://pytorch.org/docs/main/special.html#torch.special.airy_ai. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::airy_ai(x); +/// ``` +inline Tensor airy_ai(const Tensor& x) { + return torch::special_airy_ai(x); +} + +inline Tensor& airy_ai_out(Tensor& y, const Tensor& x) { + return torch::special_airy_ai_out(y, x); +} + +/// Bessel function of the first kind of order 0. +/// +/// See https://pytorch.org/docs/main/special.html#torch.special.bessel_j0. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::bessel_j0(x); +/// ``` +inline Tensor bessel_j0(const Tensor& self) { + return torch::special_bessel_j0(self); +} + +inline Tensor& bessel_j0_out(Tensor& result, const Tensor& self) { + return torch::special_bessel_j0_out(result, self); +} + +/// Bessel function of the first kind of order 1. +/// +/// See https://pytorch.org/docs/main/special.html#torch.special.bessel_j1. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::bessel_j1(x); +/// ``` +inline Tensor bessel_j1(const Tensor& self) { + return torch::special_bessel_j1(self); +} + +inline Tensor& bessel_j1_out(Tensor& result, const Tensor& self) { + return torch::special_bessel_j1_out(result, self); +} + +/// Bessel function of the second kind of order 0. +/// +/// See https://pytorch.org/docs/main/special.html#torch.special.bessel_y0. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::bessel_y0(x); +/// ``` +inline Tensor bessel_y0(const Tensor& self) { + return torch::special_bessel_y0(self); +} + +inline Tensor& bessel_y0_out(Tensor& result, const Tensor& self) { + return torch::special_bessel_y0_out(result, self); +} + +/// Bessel function of the second kind of order 1. +/// +/// See https://pytorch.org/docs/main/special.html#torch.special.bessel_y1. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::bessel_y1(x); +/// ``` +inline Tensor bessel_y1(const Tensor& self) { + return torch::special_bessel_y1(self); +} + +inline Tensor& bessel_y1_out(Tensor& result, const Tensor& self) { + return torch::special_bessel_y1_out(result, self); +} + +/// Chebyshev polynomial of the first kind. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.chebyshev_polynomial_t. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto n = torch::randn(128, dtype=kDouble); +/// +/// torch::special::chebyshev_polynomial_t(x, n); +/// ``` +inline Tensor chebyshev_polynomial_t(const Tensor& x, const Tensor& n) { + return torch::special_chebyshev_polynomial_t(x, n); +} + +inline Tensor chebyshev_polynomial_t(const Scalar& x, const Tensor& n) { + return torch::special_chebyshev_polynomial_t(x, n); +} + +inline Tensor chebyshev_polynomial_t(const Tensor& x, const Scalar& n) { + return torch::special_chebyshev_polynomial_t(x, n); +} + +inline Tensor& chebyshev_polynomial_t_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { + return torch::special_chebyshev_polynomial_t_out(output, x, n); +} + +inline Tensor& chebyshev_polynomial_t_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { + return torch::special_chebyshev_polynomial_t_out(output, x, n); +} + +inline Tensor& chebyshev_polynomial_t_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { + return torch::special_chebyshev_polynomial_t_out(output, x, n); +} + +/// Chebyshev polynomial of the second kind. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.chebyshev_polynomial_u. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto n = torch::randn(128, dtype=kDouble); +/// +/// torch::special::chebyshev_polynomial_u(x, n); +/// ``` +inline Tensor chebyshev_polynomial_u(const Tensor& x, const Tensor& n) { + return torch::special_chebyshev_polynomial_u(x, n); +} + +inline Tensor chebyshev_polynomial_u(const Scalar& x, const Tensor& n) { + return torch::special_chebyshev_polynomial_u(x, n); +} + +inline Tensor chebyshev_polynomial_u(const Tensor& x, const Scalar& n) { + return torch::special_chebyshev_polynomial_u(x, n); +} + +inline Tensor& chebyshev_polynomial_u_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { + return torch::special_chebyshev_polynomial_u_out(output, x, n); +} + +inline Tensor& chebyshev_polynomial_u_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { + return torch::special_chebyshev_polynomial_u_out(output, x, n); +} + +inline Tensor& chebyshev_polynomial_u_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { + return torch::special_chebyshev_polynomial_u_out(output, x, n); +} + +/// Chebyshev polynomial of the third kind. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.chebyshev_polynomial_v. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto n = torch::randn(128, dtype=kDouble); +/// +/// torch::special::chebyshev_polynomial_v(x, n); +/// ``` +inline Tensor chebyshev_polynomial_v(const Tensor& x, const Tensor& n) { + return torch::special_chebyshev_polynomial_v(x, n); +} + +inline Tensor chebyshev_polynomial_v(const Scalar& x, const Tensor& n) { + return torch::special_chebyshev_polynomial_v(x, n); +} + +inline Tensor chebyshev_polynomial_v(const Tensor& x, const Scalar& n) { + return torch::special_chebyshev_polynomial_v(x, n); +} + +inline Tensor& chebyshev_polynomial_v_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { + return torch::special_chebyshev_polynomial_v_out(output, x, n); +} + +inline Tensor& chebyshev_polynomial_v_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { + return torch::special_chebyshev_polynomial_v_out(output, x, n); +} + +inline Tensor& chebyshev_polynomial_v_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { + return torch::special_chebyshev_polynomial_v_out(output, x, n); +} + +/// Chebyshev polynomial of the fourth kind. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.chebyshev_polynomial_w. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto n = torch::randn(128, dtype=kDouble); +/// +/// torch::special::chebyshev_polynomial_w(x, n); +/// ``` +inline Tensor chebyshev_polynomial_w(const Tensor& x, const Tensor& n) { + return torch::special_chebyshev_polynomial_w(x, n); +} + +inline Tensor chebyshev_polynomial_w(const Scalar& x, const Tensor& n) { + return torch::special_chebyshev_polynomial_w(x, n); +} + +inline Tensor chebyshev_polynomial_w(const Tensor& x, const Scalar& n) { + return torch::special_chebyshev_polynomial_w(x, n); +} + +inline Tensor& chebyshev_polynomial_w_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { + return torch::special_chebyshev_polynomial_w_out(output, x, n); +} + +inline Tensor& chebyshev_polynomial_w_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { + return torch::special_chebyshev_polynomial_w_out(output, x, n); +} + +inline Tensor& chebyshev_polynomial_w_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { + return torch::special_chebyshev_polynomial_w_out(output, x, n); +} + +/// Physicist’s Hermite polynomial. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.hermite_polynomial_h. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto n = torch::randn(128, dtype=kDouble); +/// +/// torch::special::hermite_polynomial_h(x, n); +/// ``` +inline Tensor hermite_polynomial_h(const Tensor& x, const Tensor& n) { + return torch::special_hermite_polynomial_h(x, n); +} + +inline Tensor hermite_polynomial_h(const Scalar& x, const Tensor& n) { + return torch::special_hermite_polynomial_h(x, n); +} + +inline Tensor hermite_polynomial_h(const Tensor& x, const Scalar& n) { + return torch::special_hermite_polynomial_h(x, n); +} + +inline Tensor& hermite_polynomial_h_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { + return torch::special_hermite_polynomial_h_out(output, x, n); +} + +inline Tensor& hermite_polynomial_h_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { + return torch::special_hermite_polynomial_h_out(output, x, n); +} + +inline Tensor& hermite_polynomial_h_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { + return torch::special_hermite_polynomial_h_out(output, x, n); +} + +/// Probabilist’s Hermite polynomial. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.hermite_polynomial_he. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto n = torch::randn(128, dtype=kDouble); +/// +/// torch::special::hermite_polynomial_he(x, n); +/// ``` +inline Tensor hermite_polynomial_he(const Tensor& x, const Tensor& n) { + return torch::special_hermite_polynomial_he(x, n); +} + +inline Tensor hermite_polynomial_he(const Scalar& x, const Tensor& n) { + return torch::special_hermite_polynomial_he(x, n); +} + +inline Tensor hermite_polynomial_he(const Tensor& x, const Scalar& n) { + return torch::special_hermite_polynomial_he(x, n); +} + +inline Tensor& hermite_polynomial_he_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { + return torch::special_hermite_polynomial_he_out(output, x, n); +} + +inline Tensor& hermite_polynomial_he_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { + return torch::special_hermite_polynomial_he_out(output, x, n); +} + +inline Tensor& hermite_polynomial_he_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { + return torch::special_hermite_polynomial_he_out(output, x, n); +} + +/// Laguerre polynomial. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.laguerre_polynomial_l. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto n = torch::randn(128, dtype=kDouble); +/// +/// torch::special::laguerre_polynomial_l(x, n); +/// ``` +inline Tensor laguerre_polynomial_l(const Tensor& x, const Tensor& n) { + return torch::special_laguerre_polynomial_l(x, n); +} + +inline Tensor laguerre_polynomial_l(const Scalar& x, const Tensor& n) { + return torch::special_laguerre_polynomial_l(x, n); +} + +inline Tensor laguerre_polynomial_l(const Tensor& x, const Scalar& n) { + return torch::special_laguerre_polynomial_l(x, n); +} + +inline Tensor& laguerre_polynomial_l_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { + return torch::special_laguerre_polynomial_l_out(output, x, n); +} + +inline Tensor& laguerre_polynomial_l_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { + return torch::special_laguerre_polynomial_l_out(output, x, n); +} + +inline Tensor& laguerre_polynomial_l_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { + return torch::special_laguerre_polynomial_l_out(output, x, n); +} + +/// Legendre polynomial. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.legendre_polynomial_p. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto n = torch::randn(128, dtype=kDouble); +/// +/// torch::special::legendre_polynomial_p(x, n); +/// ``` +inline Tensor legendre_polynomial_p(const Tensor& x, const Tensor& n) { + return torch::special_legendre_polynomial_p(x, n); +} + +inline Tensor legendre_polynomial_p(const Scalar& x, const Tensor& n) { + return torch::special_legendre_polynomial_p(x, n); +} + +inline Tensor legendre_polynomial_p(const Tensor& x, const Scalar& n) { + return torch::special_legendre_polynomial_p(x, n); +} + +inline Tensor& legendre_polynomial_p_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { + return torch::special_legendre_polynomial_p_out(output, x, n); +} + +inline Tensor& legendre_polynomial_p_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { + return torch::special_legendre_polynomial_p_out(output, x, n); +} + +inline Tensor& legendre_polynomial_p_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { + return torch::special_legendre_polynomial_p_out(output, x, n); +} + +/// Modified Bessel function of the first kind of order 0. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.modified_bessel_i0. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::modified_bessel_i0(x); +/// ``` +inline Tensor modified_bessel_i0(const Tensor& self) { + return torch::special_modified_bessel_i0(self); +} + +inline Tensor& modified_bessel_i0_out(Tensor& result, const Tensor& self) { + return torch::special_modified_bessel_i0_out(result, self); +} + +/// Modified Bessel function of the first kind of order 1. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.modified_bessel_i1. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::modified_bessel_i1(x); +/// ``` +inline Tensor modified_bessel_i1(const Tensor& self) { + return torch::special_modified_bessel_i1(self); +} + +inline Tensor& modified_bessel_i1_out(Tensor& result, const Tensor& self) { + return torch::special_modified_bessel_i1_out(result, self); +} + +/// Modified Bessel function of the second kind of order 0. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.modified_bessel_k0. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::modified_bessel_k0(x); +/// ``` +inline Tensor modified_bessel_k0(const Tensor& self) { + return torch::special_modified_bessel_k0(self); +} + +inline Tensor& modified_bessel_k0_out(Tensor& result, const Tensor& self) { + return torch::special_modified_bessel_k0_out(result, self); +} + +/// Modified Bessel function of the second kind of order 1. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.modified_bessel_k1. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::modified_bessel_k1(x); +/// ``` +inline Tensor modified_bessel_k1(const Tensor& self) { + return torch::special_modified_bessel_k1(self); +} + +inline Tensor& modified_bessel_k1_out(Tensor& result, const Tensor& self) { + return torch::special_modified_bessel_k1_out(result, self); +} + +/// Scaled modified Bessel function of the second kind of order 0. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.scaled_modified_bessel_k0. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::scaled_modified_bessel_k0(x); +/// ``` +inline Tensor scaled_modified_bessel_k0(const Tensor& x) { + return torch::special_scaled_modified_bessel_k0(x); +} + +inline Tensor& scaled_modified_bessel_k0_out(Tensor& y, const Tensor& x) { + return torch::special_scaled_modified_bessel_k0_out(y, x); +} + +/// Scaled modified Bessel function of the second kind of order 1. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.scaled_modified_bessel_k1. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::scaled_modified_bessel_k1(x); +/// ``` +inline Tensor scaled_modified_bessel_k1(const Tensor& x) { + return torch::special_scaled_modified_bessel_k1(x); +} + +inline Tensor& scaled_modified_bessel_k1_out(Tensor& y, const Tensor& x) { + return torch::special_scaled_modified_bessel_k1_out(y, x); +} + +/// Shifted Chebyshev polynomial of the first kind. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.shifted_chebyshev_polynomial_t. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto n = torch::randn(128, dtype=kDouble); +/// +/// torch::special::shifted_chebyshev_polynomial_t(x, n); +/// ``` +inline Tensor shifted_chebyshev_polynomial_t(const Tensor& x, const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_t(x, n); +} + +inline Tensor shifted_chebyshev_polynomial_t(const Scalar& x, const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_t(x, n); +} + +inline Tensor shifted_chebyshev_polynomial_t(const Tensor& x, const Scalar& n) { + return torch::special_shifted_chebyshev_polynomial_t(x, n); +} + +inline Tensor& shifted_chebyshev_polynomial_t_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_t_out(output, x, n); +} + +inline Tensor& shifted_chebyshev_polynomial_t_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_t_out(output, x, n); +} + +inline Tensor& shifted_chebyshev_polynomial_t_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { + return torch::special_shifted_chebyshev_polynomial_t_out(output, x, n); +} + +/// Shifted Chebyshev polynomial of the second kind. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.shifted_chebyshev_polynomial_u. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto n = torch::randn(128, dtype=kDouble); +/// +/// torch::special::shifted_chebyshev_polynomial_u(x, n); +/// ``` +inline Tensor shifted_chebyshev_polynomial_u(const Tensor& x, const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_u(x, n); +} + +inline Tensor shifted_chebyshev_polynomial_u(const Scalar& x, const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_u(x, n); +} + +inline Tensor shifted_chebyshev_polynomial_u(const Tensor& x, const Scalar& n) { + return torch::special_shifted_chebyshev_polynomial_u(x, n); +} + +inline Tensor& shifted_chebyshev_polynomial_u_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_u_out(output, x, n); +} + +inline Tensor& shifted_chebyshev_polynomial_u_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_u_out(output, x, n); +} + +inline Tensor& shifted_chebyshev_polynomial_u_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { + return torch::special_shifted_chebyshev_polynomial_u_out(output, x, n); +} + +/// Shifted Chebyshev polynomial of the third kind. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.shifted_chebyshev_polynomial_v. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto n = torch::randn(128, dtype=kDouble); +/// +/// torch::special::shifted_chebyshev_polynomial_v(x, n); +/// ``` +inline Tensor shifted_chebyshev_polynomial_v(const Tensor& x, const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_v(x, n); +} + +inline Tensor shifted_chebyshev_polynomial_v(const Scalar& x, const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_v(x, n); +} + +inline Tensor shifted_chebyshev_polynomial_v(const Tensor& x, const Scalar& n) { + return torch::special_shifted_chebyshev_polynomial_v(x, n); +} + +inline Tensor& shifted_chebyshev_polynomial_v_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_v_out(output, x, n); +} + +inline Tensor& shifted_chebyshev_polynomial_v_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_v_out(output, x, n); +} + +inline Tensor& shifted_chebyshev_polynomial_v_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { + return torch::special_shifted_chebyshev_polynomial_v_out(output, x, n); +} + +/// Shifted Chebyshev polynomial of the fourth kind. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.shifted_chebyshev_polynomial_w. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// auto n = torch::randn(128, dtype=kDouble); +/// +/// torch::special::shifted_chebyshev_polynomial_w(x, n); +/// ``` +inline Tensor shifted_chebyshev_polynomial_w(const Tensor& x, const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_w(x, n); +} + +inline Tensor shifted_chebyshev_polynomial_w(const Scalar& x, const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_w(x, n); +} + +inline Tensor shifted_chebyshev_polynomial_w(const Tensor& x, const Scalar& n) { + return torch::special_shifted_chebyshev_polynomial_w(x, n); +} + +inline Tensor& shifted_chebyshev_polynomial_w_out( + Tensor& output, + const Tensor& x, + const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_w_out(output, x, n); +} + +inline Tensor& shifted_chebyshev_polynomial_w_out( + Tensor& output, + const Scalar& x, + const Tensor& n) { + return torch::special_shifted_chebyshev_polynomial_w_out(output, x, n); +} + +inline Tensor& shifted_chebyshev_polynomial_w_out( + Tensor& output, + const Tensor& x, + const Scalar& n) { + return torch::special_shifted_chebyshev_polynomial_w_out(output, x, n); +} + +/// Spherical Bessel function of the first kind of order 0. +/// +/// See +/// https://pytorch.org/docs/main/special.html#torch.special.spherical_bessel_j0. +/// +/// Example: +/// +/// ``` +/// auto x = torch::randn(128, dtype=kDouble); +/// +/// torch::special::spherical_bessel_j0(x); +/// ``` +inline Tensor spherical_bessel_j0(const Tensor& x) { + return torch::special_spherical_bessel_j0(x); +} + +inline Tensor& spherical_bessel_j0_out(Tensor& y, const Tensor& x) { + return torch::special_spherical_bessel_j0_out(y, x); +} +} // namespace torch::special diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/torch.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/torch.h new file mode 100644 index 0000000000000000000000000000000000000000..7316af88d2eba7337086b29d099370bf30aa99dc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/torch.h @@ -0,0 +1,8 @@ +#pragma once + +#include + +#ifdef TORCH_API_INCLUDE_EXTENSION_H +#include + +#endif // defined(TORCH_API_INCLUDE_EXTENSION_H) diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/types.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/types.h new file mode 100644 index 0000000000000000000000000000000000000000..c00e676efe135b71af24200d80e2f635bf1352e1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/types.h @@ -0,0 +1,70 @@ +#pragma once + +#include + +#include + +#include +#include + +#include + +namespace torch { + +// NOTE [ Exposing declarations in `at::` to `torch::` ] +// +// The following line `using namespace at;` is responsible for exposing all +// declarations in `at::` namespace to `torch::` namespace. +// +// According to the rules laid out in +// https://en.cppreference.com/w/cpp/language/qualified_lookup, section +// "Namespace members": +// ``` +// Qualified lookup within the scope of a namespace N first considers all +// declarations that are located in N and all declarations that are located in +// the inline namespace members of N (and, transitively, in their inline +// namespace members). If there are no declarations in that set then it +// considers declarations in all namespaces named by using-directives found in N +// and in all transitive inline namespace members of N. +// ``` +// +// This means that if both `at::` and `torch::` namespaces have a function with +// the same signature (e.g. both `at::func()` and `torch::func()` exist), after +// `namespace torch { using namespace at; }`, when we call `torch::func()`, the +// `func()` function defined in `torch::` namespace will always be called, and +// the `func()` function defined in `at::` namespace is always hidden. +using namespace at; // NOLINT + +#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) +using std::nullopt; // NOLINT +using std::optional; // NOLINT +#endif + +using Dtype = at::ScalarType; + +/// Fixed width dtypes. +constexpr auto kUInt8 = at::kByte; +constexpr auto kInt8 = at::kChar; +constexpr auto kInt16 = at::kShort; +constexpr auto kInt32 = at::kInt; +constexpr auto kInt64 = at::kLong; +constexpr auto kUInt16 = at::kUInt16; +constexpr auto kUInt32 = at::kUInt32; +constexpr auto kUInt64 = at::kUInt64; +constexpr auto kFloat16 = at::kHalf; +constexpr auto kFloat32 = at::kFloat; +constexpr auto kFloat64 = at::kDouble; + +/// Rust-style short dtypes. +constexpr auto kU8 = kUInt8; +constexpr auto kU16 = kUInt16; +constexpr auto kU32 = kUInt32; +constexpr auto kU64 = kUInt64; +constexpr auto kI8 = kInt8; +constexpr auto kI16 = kInt16; +constexpr auto kI32 = kInt32; +constexpr auto kI64 = kInt64; +constexpr auto kF16 = kFloat16; +constexpr auto kF32 = kFloat32; +constexpr auto kF64 = kFloat64; +} // namespace torch diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/utils.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a517043fa3ff881072eb5a232740e7d304f51a1a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/utils.h @@ -0,0 +1,117 @@ +#pragma once + +#include +#include +#include +#include +#include + +// NOLINTBEGIN(misc-unused-using-decls) +namespace torch { + +/// A RAII, thread-local guard that disabled gradient calculation. +/// +/// Disabling gradient calculation is useful for inference, when you are sure +/// that you will not call `at::Tensor::backward`. It will reduce memory +/// consumption for computations that would otherwise have `requires_grad() == +/// true`. +/// +/// In this mode, the result of every computation will have +/// `requires_grad() == false`, even when the inputs have `requires_grad() == +/// true`. +/// +/// This context manager is thread-local; it will not affect computation +/// in other threads. +/// +/// Example: +/// @code +/// auto x = torch::tensor({1.}, torch::requires_grad()); +/// { +/// torch::NoGradGuard no_grad; +/// auto y = x * 2; +/// std::cout << y.requires_grad() << std::endl; // prints `false` +/// } +/// { +/// auto doubler = [](torch::Tensor x) { +/// torch::NoGradGuard no_grad; +/// return x * 2; +/// }; +/// auto z = doubler(x); +/// std::cout << z.requires_grad() << std::endl; // prints `false` +/// } +/// @endcode +using NoGradGuard = at::NoGradGuard; + +/// A RAII, thread-local guard that sets gradient calculation to on or off. +/// +/// ``AutoGradMode`` will enable or disable grads based on its argument +/// `enabled`. +/// +/// This context manager is thread-local; it will not affect computation +/// in other threads. +/// +/// \param enabled: Flag whether to enable grad (``true``), or disable +/// (``false``). This can be used to conditionally enable +/// gradients. +/// +/// Example: +/// @code +/// auto x = torch::tensor({1.}, torch::requires_grad()); +/// { +/// torch::AutoGradMode enable_grad(true); +/// auto y = x * 2; +/// std::cout << y.requires_grad() << std::endl; // prints `true` +/// } +/// { +/// torch::AutoGradMode enable_grad(false); +/// auto y = x * 2; +/// std::cout << y.requires_grad() << std::endl; // prints `false` +/// } +/// @endcode +using AutoGradMode = at::AutoGradMode; + +/// Sets the global random seed for all newly created CPU and CUDA tensors. +using at::manual_seed; + +// Called during new thread initialization +using at::init_num_threads; + +// Returns the number of threads used in parallel region. +using at::get_num_threads; + +// Sets the number of threads to be used in parallel region. +using at::set_num_threads; + +// Returns the number of threads used for inter-op parallelism. +using at::get_num_interop_threads; + +// Sets the number of threads to be used for inter-op parallelism. +using at::set_num_interop_threads; + +// Returns true if both t1, t2 are undefined or both are defined and equal +inline bool equal_if_defined(const Tensor& t1, const Tensor& t2) { + return ( + (!t1.defined() && !t2.defined()) || + (t1.defined() && t2.defined() && torch::equal(t1, t2))); +} + +// RecordFunction API +using at::addGlobalCallback; +using at::addThreadLocalCallback; +using at::CallbackHandle; +using at::clearCallbacks; +using at::clearGlobalCallbacks; +using at::clearThreadLocalCallbacks; +using at::DisableRecordFunctionGuard; +using at::enableRecordFunction; +using at::hasCallbacks; +using at::hasGlobalCallbacks; +using at::hasThreadLocalCallbacks; +using at::isRecordFunctionEnabled; +using at::RecordFunction; +using at::RecordFunctionCallback; +using at::RecordFunctionGuard; +using at::removeCallback; + +} // namespace torch +// NOLINTEND(misc-unused-using-decls) diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/version.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/version.h new file mode 100644 index 0000000000000000000000000000000000000000..2812814f776cce006fbc3c3d8a153943f053656d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/version.h @@ -0,0 +1,26 @@ +#pragma once + +/// Indicates the major version of LibTorch. +#define TORCH_VERSION_MAJOR 2 + +/// Indicates the minor version of LibTorch. +#define TORCH_VERSION_MINOR 8 + +/// Indicates the patch version of LibTorch. +#define TORCH_VERSION_PATCH 0 + +/// Indicates the ABI version tag of LibTorch. +#define TORCH_VERSION_ABI_TAG 0 + +/// Indicates the version of LibTorch as a string literal. +#define TORCH_VERSION \ + "2.8.0" + +/// Indicates the ABI version of LibTorch as a single uint64. +/// [ byte ][ byte ][ byte ][ byte ][ byte ][ byte ][ byte ][ byte ] +/// [ MAJ ][ MIN ][ PATCH][ ABI TAG ] +#define TORCH_ABI_VERSION \ + (uint64_t)TORCH_VERSION_MAJOR << 56 | \ + (uint64_t)TORCH_VERSION_MINOR << 48 | \ + (uint64_t)TORCH_VERSION_PATCH << 40 | \ + TORCH_VERSION_ABI_TAG << 0 diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/xpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/xpu.h new file mode 100644 index 0000000000000000000000000000000000000000..76eed364aaf9f4aff91ebf99ddb09a71cdf0d455 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/api/include/torch/xpu.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include +#include + +namespace torch::xpu { + +/// Returns the number of XPU devices available. +size_t TORCH_API device_count(); + +/// Returns true if at least one XPU device is available. +bool TORCH_API is_available(); + +/// Sets the seed for the current GPU. +void TORCH_API manual_seed(uint64_t seed); + +/// Sets the seed for all available GPUs. +void TORCH_API manual_seed_all(uint64_t seed); + +/// Waits for all kernels in all streams on a XPU device to complete. +void TORCH_API synchronize(int64_t device_index); + +} // namespace torch::xpu diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/autograd/utils/grad_layout_contract.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/autograd/utils/grad_layout_contract.h new file mode 100644 index 0000000000000000000000000000000000000000..ed97dc4530eb418cdbc84bc456273445be62cae7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/autograd/utils/grad_layout_contract.h @@ -0,0 +1,76 @@ +#pragma once + +#include + +namespace torch::autograd::utils { + +// Helper functions to enforce the "Gradient Layout Contract" described in +// torch/csrc/autograd/functions/accumulate_grad.h. + +// Checks if grad obeys the contract with variable. +inline bool obeys_layout_contract( + const at::Tensor& grad, + const at::Tensor& variable) { + TORCH_INTERNAL_ASSERT(!grad.is_sparse()); + TORCH_INTERNAL_ASSERT(!grad.is_sparse_csr()); + TORCH_INTERNAL_ASSERT(!variable.is_sparse_csr()); + + // NOLINTNEXTLINE(bugprone-branch-clone) + if (variable.is_nested()) { + // TODO: Nested Tensor does not have an implementation of detach. The + // current implementation of nested tensor likely does obey the gradient + // contract and should return true, but this would likely change in the + // future + return false; + } else if (variable.is_sparse()) { + // Gradient Layout Contract is not applicable for sparse layouts + return false; + } else if (variable.is_non_overlapping_and_dense()) { + // Only look at stride for dimensions that are not of size 1. + const auto& grad_sizes = grad.sym_sizes(); + const auto& grad_strides = grad.sym_strides(); + const auto& variable_strides = variable.sym_strides(); + for (const auto idx : c10::irange(grad_sizes.size())) { + if (grad_sizes[idx] != 1) { + if (grad_strides[idx] != variable_strides[idx]) { + return false; + } + } else { + // This should not be needed but we don't check if a Tensor has views + // before stashing it. And 0-strided Tensors of size 1 are actually + // views for ops like cat. + // TODO: Actually detect views in the accumulateGrad function so that + // this Tensor is not considered at all. + if (grad_strides[idx] == 0) { + return false; + } + } + } + return true; + } else { + return grad.is_contiguous(at::MemoryFormat::Contiguous); + } +} + +// Creates a clone of new_grad that obeys the contract with variable. +// The clone should attach to new_grad's history if GradMode::is_enabled(). +inline at::Tensor clone_obey_contract( + const at::Tensor& new_grad, + const at::Tensor& variable) { + if (variable.is_non_overlapping_and_dense()) { + // (1) + // Does this dicey-looking sequence attach the result to new_grad's + // history if GradMode::is_enabled()? Yes, and @alband says it should. + return std::move(new_grad + .new_empty_strided_symint( + variable.sym_sizes(), + variable.sym_strides(), + variable.options().memory_format(std::nullopt)) + .copy_(new_grad)); + } else { + // (2) + return new_grad.clone(at::MemoryFormat::Contiguous); + } +} + +} // namespace torch::autograd::utils diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/autograd/utils/lambda_post_hook.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/autograd/utils/lambda_post_hook.h new file mode 100644 index 0000000000000000000000000000000000000000..e43d7a23876df5baaf63be088728673a0265f1fb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/autograd/utils/lambda_post_hook.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +namespace torch::autograd::utils { + +// Turns lambda into a torch::autograd::FunctionPostHook. +class LambdaPostHook : public torch::autograd::FunctionPostHook { + using variable_list = std::vector; + using fn_type = + std::function; + using compiled_fn_type = std::function; + + public: + // The lambda function takes as arguments the outputs and inputs of the + // autograd function and can modify the outputs of the autograd function by + // returning a new output if needed. + /* implicit */ LambdaPostHook(fn_type fn) : fn_(std::move(fn)) {} + + LambdaPostHook(fn_type fn, compiled_fn_type compiled_fn) + : fn_(std::move(fn)), compiled_fn_(std::move(compiled_fn)) {} + + variable_list operator()( + const variable_list& outputs, + const variable_list& inputs) override { + return fn_(outputs, inputs); + } + + void compiled_args(CompiledNodeArgs& args) const override { + if (compiled_fn_ != nullptr) { + return compiled_fn_(args); + } + return FunctionPostHook::compiled_args(args); + } + + protected: + std::function fn_; + compiled_fn_type compiled_fn_{}; +}; + +} // namespace torch::autograd::utils diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/autograd/utils/python_arg_parsing.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/autograd/utils/python_arg_parsing.h new file mode 100644 index 0000000000000000000000000000000000000000..9d4ec8dfcd3af072f6375093d2a2dd4182af4d86 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/autograd/utils/python_arg_parsing.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include + +#include + +namespace torch::autograd::utils { + +// The parameter allow_copy is to accept copy for Tensor.to (and by proxy +// PackedSequences.to) but not nn.Module.to. +inline std::tuple< + std::optional, + std::optional, + bool, + bool, + std::optional> +parse_to_conversion(PythonArgs& r, bool allow_copy) { + if (r.idx == 0) { + if (!allow_copy && !r.isNone(3)) + throw std::runtime_error(".to() does not accept copy argument"); + return std::make_tuple( + r.deviceOptional(0), + r.scalartypeOptional(1), + r.toBool(2), + r.toBool(3), + r.memoryformatOptional(4)); + } else if (r.idx == 1) { + if (!allow_copy && !r.isNone(2)) + throw std::runtime_error(".to() does not accept copy argument"); + return std::make_tuple( + std::nullopt, + r.scalartype(0), + r.toBool(1), + r.toBool(2), + r.memoryformatOptional(3)); + } else { + auto tensor = r.tensor(0); + if (!allow_copy && !r.isNone(2)) + throw std::runtime_error(".to() does not accept copy argument"); + return std::make_tuple( + tensor.device(), + tensor.scalar_type(), + r.toBool(1), + r.toBool(2), + r.memoryformatOptional(3)); + } +} +} // namespace torch::autograd::utils diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/autograd/utils/warnings.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/autograd/utils/warnings.h new file mode 100644 index 0000000000000000000000000000000000000000..ced5663ef4ed1da825d84bf022e01caf94ca9d2d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/autograd/utils/warnings.h @@ -0,0 +1,24 @@ +#pragma once +#include + +#include +#include + +namespace torch::autograd::utils { + +// Warning handler for multi-threaded contexts. Gather warnings from +// all threads into a single queue, then process together at the end +// in the main thread. +class DelayWarningHandler : public at::WarningHandler { + public: + ~DelayWarningHandler() override = default; + void replay_warnings(); + + private: + void process(const c10::Warning& warning) override; + + std::vector warnings_; + std::mutex mutex_; +}; + +} // namespace torch::autograd::utils diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_eager/kernel_holder.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_eager/kernel_holder.h new file mode 100644 index 0000000000000000000000000000000000000000..8459b35c683730e6d2ed9f153e1a07d96fb25442 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_eager/kernel_holder.h @@ -0,0 +1,112 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#include + +namespace torch::inductor { + +// Represent AOTI kernel. It contains all the parameter metadata of the kernel +// and the AOTI model runner. +struct AOTIKernelMetadata { + // Represent all the parameters of AOTI kernel + std::vector parameter_metadata_list_; + // AOTI model runner to run the AOTI kernel + std::shared_ptr kernel_runner_; + AOTIKernelMetadata() : kernel_runner_(nullptr) {} + + // Check whether the given parameter metadata list is the same as the + // parameter metadata list of the AOTI kernel. + bool check( + const std::vector& parameter_metadata_list) const { + if (parameter_metadata_list_.size() != parameter_metadata_list.size()) { + return false; + } + + for (size_t i = 0; i < parameter_metadata_list_.size(); ++i) { + if (parameter_metadata_list_[i] == parameter_metadata_list[i]) { + continue; + } else { + return false; + } + } + + return true; + } +}; + +// The AOTIPythonKernelHolder class uses the AOT Inductor to generate a kernel +// for a specified operation. To speed up this process, the generated kernel +// library is cached on disk. Detailed information from the input tensors is +// used as the key for caching the kernel library. On subsequent runs, these +// input tensors are used to search the cache. If a cache hit occurs, the cached +// kernel library is loaded and executed. If a cache miss occurs, the AOT +// Inductor is called again to generate the kernel library. +class AOTIPythonKernelHolder : public c10::OperatorKernel { + // A DispatchKey object that represents the dispatch key for the kernel. + c10::DispatchKey dispatch_key_; + // Namespace of the kernel. + std::string ns_; + // Name of the operation the kernel performs. + std::string op_name_with_overload_; + // The device on which the kernel is to be executed. + c10::Device device_; + // The Python interpreter to get OpOverload object with the given op_name and + // op_overload_name. + c10::impl::PyInterpreter* pyinterpreter_; + // Cache the produced kernels by AOTI and its metadata + std::vector aoti_kernel_cache_; + + public: + AOTIPythonKernelHolder( + c10::DispatchKey dispatch_key, + std::string_view ns, + std::string_view op_name_with_overload); + + void operator()( + const c10::OperatorHandle& op, + c10::DispatchKeySet keyset, + torch::jit::Stack* stack); + + private: + bool cache_lookup( + const c10::OperatorHandle& op, + const c10::DispatchKeySet& keyset, + const torch::jit::Stack* stack, + AOTIKernelMetadata& aoti_kernel_metadata); + void cache_miss( + const c10::OperatorHandle& op, + const c10::DispatchKeySet& keyset, + torch::jit::Stack* stack); + void cache_hit( + const AOTIKernelMetadata& aoti_kernel_metadata, + const c10::OperatorHandle& op, + const c10::DispatchKeySet& keyset, + torch::jit::Stack* stack); + // Invoke python utility function on the Inductor side to produce AOTI kernel + // for the given operation. + // Inductor utility function - + // torch._inductor.utils.aoti_compile_with_persistent_cache + std::string produce_aoti_kernel_lib( + const c10::OperatorHandle& op, + const c10::DispatchKeySet& keyset, + const torch::jit::Stack* stack); + // Invoke python utility function on the Inductor side to load AOTI kernel for + // the given operation. + // Inductor utility function - torch._inductor.utils.load_aoti_eager_cache + void init_aoti_kernel_cache(); + // Load the AOTIModelContainerRunner object from the given file path. + std::shared_ptr load_aoti_model_runner( + const std::string&); +}; + +} // namespace torch::inductor +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_eager/kernel_meta_info.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_eager/kernel_meta_info.h new file mode 100644 index 0000000000000000000000000000000000000000..24d3c05bc3505c036465f71161e5e91d81b87856 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_eager/kernel_meta_info.h @@ -0,0 +1,142 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include +#include + +#include + +namespace torch::inductor { + +// Regarding a aten operation implemented by AOTI, the metadata of the input +// tensors will be cached on the disk to accelerate next run. TensorMetada +// structure is to represent the metadata of each input tensor. It includes +// whether the tensor is symbolic, the dtype, the device, the sizes and the +// strides of the tensor. When the metadata of the input tensors is the same as +// the cached metadata, the cached kernel library will be loaded and executed. +// Otherwise, the AOT Inductor will be called again to generate the kernel +// library. +// Beyond the TensorMetadata, we build guard/TensorCheck for each input tensor +// as well to support symbolic shape. We intend to utilize TensorCheck to find +// out the proper kernel rather than TensorMetada comparison. Suppose an +// operation with a single input tensor and two kernels: +// kernel1: TensorMetadata(is_symbolic=false, dtype=Float, device=CPU, +// sizes=[s0, s1, s2], strides=[s1 * s2, s2, 1]) kernel2: +// TensorMetadata(is_symbolic=false, dtype=Float, device=CPU, sizes=[3, s1, +// s2], strides=[s1 * s2, s2, 1]) +// If a tensor with sizes=[3, 4, 5] is passed to the operation, both kernel1 and +// kernel2 support the tensor shape. In this case, we need to use TensorCheck +// plus some heruistic rules to find out the proper kernel. +struct TensorMetadata { + // Indicate whether the tensor is symbolic and it may be concluded by sizes_ + // and strides_ in the future. + bool is_symbolic_; + // Dtype of a tensor(For scalar, we will wrap it as a scalar tensor) + c10::ScalarType dtype_ = c10::ScalarType::Undefined; + // Device of a tensor. + c10::Device device_; + // Dispatch key set of a tensor + c10::DispatchKeySet dispatch_key_set_; + // Sizes of a tensor. Currently, we only support static shape and use int64_t + // to represent the sizes. In the future, we will create symbolic size and use + // SymInt to represent it to support symbolic shape. + std::vector sizes_; + // Strides of a tensor. For symbolic shape support, it is the same as sizes_ + std::vector strides_; + // requires grad + bool requires_grad_ = false; + // TensorCheck for the tensor + std::optional tensor_check_; + + TensorMetadata() + : is_symbolic_(false), + device_(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES), + sizes_({}), + strides_({}) {} + TensorMetadata(const at::Tensor& src_tensor); + TensorMetadata( + bool is_symbolic, + c10::ScalarType dtype, + c10::Device device, + c10::DispatchKeySet dispatch_key_set, + std::vector sizes, + std::vector strides, + bool requires_grad = false); + + // Build TensorCheck for the tensor by using the data fields in TensorMetadata + void build_guard(const dynamo::LocalState& local_state); + + // Compare two TensorMetadata objects + bool operator==(const TensorMetadata& other) const; +}; + +// ParameterTag is to represent the type of the input parameters of a aten +// operation. Currently, we support the following types: +// 1. TENSOR: a single tensor +// 2. TENSOR_OPTIONAL: a single optional tensor +// 3. TENSOR_LIST: a list of tensors +// 4. TENSOR_LIST_OPTIONAL: a list of optional tensors +// 5. SCALAR: a scalar value +// If we need to support more types in the future, we will add more types in the +// ParameterTag enum. For example, we will extend the enum to support string, +// Dimname and so on to support more types of input parameters of aten +// operations. +enum ParameterTag { + TENSOR, + TENSOR_OPTIONAL, + TENSOR_LIST, + TENSOR_LIST_OPTIONAL, + SCALAR, + STRING, + DEVICE, + INVALID, +}; + +// ParameterMetadataValue is to represent the value of the input parameters of a +// aten operation. +using ParameterMetadataValue = std::variant< + TensorMetadata, + std::vector, + c10::Scalar, + std::string, + c10::Device>; + +// ParameterMetadata is to represent the metadata of the input parameters of a +// aten operation. It includes the tag of the parameter, the value of the +// parameter and the order of the parameter. +struct ParameterMetadata { + // The tag of the parameter. It indicates the type of the parameter. + ParameterTag tag_; + // The value of the parameter. It can be a tensor, a list of tensors or a + // scalar. + ParameterMetadataValue value_; + // The order of the parameter is used to distinguish the parameters with the + // same tag. For example, an operation with two input tensors, the first + // tensor is a optional tensor and the second tensor is a tensor. The first + // tensor will have the order 0 and the second tensor will have the order 1. + uint64_t order_{}; + + ParameterMetadata() : tag_(INVALID) {} + ParameterMetadata(TensorMetadata tensor_metadata, uint64_t input_order); + ParameterMetadata(const at::Tensor& tensor, uint64_t input_order); + ParameterMetadata( + const std::vector& tensor_list, + uint64_t input_order); + ParameterMetadata( + const std::vector& tensor_metadata_list, + uint64_t input_order); + ParameterMetadata(const c10::Scalar& scalar, uint64_t input_order); + ParameterMetadata(const std::string& string_value, uint64_t input_order); + ParameterMetadata(const c10::Device& device, uint64_t input_order); + + bool operator==(const ParameterMetadata& other) const; + + private: + // Helper function to compare two ParameterMetadata objects with the same + // SCALAR tag. + bool equal_to(const c10::Scalar& scalar) const; +}; + +} // namespace torch::inductor +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/array_ref.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/array_ref.h new file mode 100644 index 0000000000000000000000000000000000000000..35b7e168e69384caaeaef92685adc2421f2a6c5b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/array_ref.h @@ -0,0 +1,7 @@ +#pragma once + +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/common.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/common.h new file mode 100644 index 0000000000000000000000000000000000000000..e0e61ac0615d74c9e4417b63b04d3c29a6b19515 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/common.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include + +// Round up to the nearest multiple of 64 +[[maybe_unused]] inline int64_t align(int64_t nbytes) { + return (nbytes + 64 - 1) & -64; +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/cpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..508a15b45635ef4fd4f127c86b48190398a5c79d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/cpu.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/cuda.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..59948abf1714b47722b5cb3c7ceb2b07c21349a6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/cuda.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/mps.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/mps.h new file mode 100644 index 0000000000000000000000000000000000000000..a96ea0f7eed71dc17f13c3a15f4c8e25e689c25f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/mps.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/xpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/xpu.h new file mode 100644 index 0000000000000000000000000000000000000000..d0e15b13f11f3683d8330a4943fb5891448275e3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_include/xpu.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_package/model_package_loader.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_package/model_package_loader.h new file mode 100644 index 0000000000000000000000000000000000000000..db990de26d3bdec15786cccd41ca973d0f1a53dc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_package/model_package_loader.h @@ -0,0 +1,54 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include +#include + +namespace torch::inductor { +class TORCH_API AOTIModelPackageLoader { + public: + AOTIModelPackageLoader( + const std::string& model_package_path, + const std::string& model_name = "model", + const bool run_single_threaded = false, + const size_t num_runners = 1, + const c10::DeviceIndex device_index = -1); + ~AOTIModelPackageLoader(); + + AOTIModelContainerRunner* get_runner(); + std::unordered_map get_metadata(); + + std::vector run( + const std::vector& inputs, + void* stream_handle = nullptr); + + // boxed_run will steal the ownership of the input tensors + std::vector boxed_run( + std::vector&& inputs, + void* stream_handle = nullptr); + + std::vector get_call_spec(); + void load_constants( + std::unordered_map& constants_map, + bool use_inactive, + bool check_full_update, + bool user_managed = false); + std::vector get_constant_fqns(); + + void update_constant_buffer( + std::unordered_map& tensor_map, + bool use_inactive, + bool validate_full_updates, + bool user_managed = false); + + private: + std::string temp_dir_; + std::unique_ptr runner_; + std::unordered_map metadata_; + + void load_metadata(const std::string& cpp_filename); +}; + +} // namespace torch::inductor +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_package/pybind.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_package/pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..1eb7818c00e906b9069f64298b3901450bde88c1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_package/pybind.h @@ -0,0 +1,7 @@ +#include + +namespace torch::inductor { + +void initAOTIPackageBindings(PyObject* module); + +} // namespace torch::inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..39065dab187f9aa16965fb58467a898843b4f50d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner.h @@ -0,0 +1,135 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include +#include + +// Forward declare DynamicLibrary +namespace at { +struct DynamicLibrary; +} + +namespace torch::inductor { +using TensorConstantMap = std::unordered_map; + +class TORCH_API AOTIModelContainerRunner { + public: + AOTIModelContainerRunner() = delete; + AOTIModelContainerRunner(const AOTIModelContainerRunner& other) = delete; + AOTIModelContainerRunner(AOTIModelContainerRunner&& other) = delete; + AOTIModelContainerRunner& operator=(const AOTIModelContainerRunner& other) = + delete; + AOTIModelContainerRunner& operator=(AOTIModelContainerRunner&& other) = + delete; + virtual ~AOTIModelContainerRunner(); + + std::vector run( + const std::vector& inputs, + void* stream_handle = nullptr); + + // boxed_run will steal the ownership of the input tensors + std::vector boxed_run( + std::vector&& inputs, + void* stream_handle = nullptr); + + std::unordered_map getConstantNamesToOriginalFQNs() + const; + std::unordered_map getConstantNamesToDtypes() const; + + const std::unordered_map extract_constants_map( + bool use_inactive) const; + void update_inactive_constant_buffer(const TensorConstantMap& const_map); + void update_constant_buffer( + std::unordered_map& tensor_map, + bool use_inactive, + bool validate_full_updates, + bool user_managed = false); + void update_constant_buffer( + const TensorConstantMap& const_map, + bool use_inactive, + bool validate_full_updates, + bool user_managed = false); + void run_const_fold( + bool use_inactive, + AOTInductorStreamHandle cuda_stream_handle = nullptr); + void swap_constant_buffer(); + void free_inactive_constant_buffer(); + + std::vector get_call_spec(); + + protected: + AOTIModelContainerRunner( + const std::string& model_so_path, + size_t num_models, + const std::string& device_str, + const std::string& cubin_dir, + const bool run_single_threaded); + + virtual std::vector run_impl( + std::vector& input_handles, + void* stream_handle); + + std::unique_ptr model_so_; + decltype(&AOTInductorModelContainerCreateWithDevice) create_func_{nullptr}; + decltype(&AOTInductorModelContainerDelete) delete_func_{nullptr}; + decltype(&AOTInductorModelContainerGetNumOutputs) get_num_outputs_func_{ + nullptr}; + decltype(&AOTInductorModelContainerRun) run_func_{nullptr}; + decltype(&AOTInductorModelContainerGetNumConstants) get_num_constants_func_{ + nullptr}; + decltype(&AOTInductorModelContainerGetConstantName) get_constant_name_func_{ + nullptr}; + decltype(&AOTInductorModelContainerGetConstantOriginalFQN) + get_constant_original_fqn_func_{nullptr}; + decltype(&AOTInductorModelContainerGetConstantDtype) get_constant_dtype_func_{ + nullptr}; + decltype(&AOTInductorModelContainerExtractConstantsMap) + extract_constants_map_func_{nullptr}; + decltype(&AOTInductorModelContainerUpdateUserManagedConstantBuffer) + update_user_managed_constant_buffer_func_{nullptr}; + decltype(&AOTInductorModelContainerUpdateConstantBuffer) + update_constant_buffer_func_{nullptr}; + decltype(&AOTInductorModelContainerUpdateInactiveConstantBuffer) + update_inactive_constant_buffer_func_{nullptr}; + decltype(&AOTInductorModelContainerRunConstantFolding) run_const_fold_func_{ + nullptr}; + decltype(&AOTInductorModelContainerSwapConstantBuffer) + swap_constant_buffer_func_{nullptr}; + decltype(&AOTInductorModelContainerFreeInactiveConstantBuffer) + free_inactive_constant_buffer_func_{nullptr}; + decltype(&AOTInductorModelContainerGetCallSpec) get_call_spec_func_{nullptr}; + + AOTInductorModelContainerHandle container_handle_ = nullptr; + + AOTIProxyExecutorHandle proxy_executor_handle_; + + private: + std::unique_ptr proxy_executor_; +}; + +using CreateAOTIModelRunnerFunc = std::unique_ptr (*)( + const std::string& model_so_path, + size_t num_models, + const std::string& device_str, + const std::string& bin_dir, + const bool run_single_threaded); + +// Return a global map "device name" -> "aoti model runner create function" for +// all registered in AOTI external backends +TORCH_API std::unordered_map& +getAOTIModelRunnerRegistry(); + +// To register a new external backend in AOTI one needs to create an instance of +// this struct. It is not thread-safe. Because it is expected to be called +// during the initialization of the program. +struct TORCH_API RegisterAOTIModelRunner{RegisterAOTIModelRunner( + const std::string& name, + CreateAOTIModelRunnerFunc create_aoti_model_runner_fn){ + getAOTIModelRunnerRegistry()[name] = create_aoti_model_runner_fn; +} // namespace torch::inductor +} +; + +} // namespace torch::inductor +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..a15485928b9eef297c264ed267ca113dd934920c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h @@ -0,0 +1,18 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include + +namespace torch::inductor { +class TORCH_API AOTIModelContainerRunnerCpu : public AOTIModelContainerRunner { + public: + AOTIModelContainerRunnerCpu( + const std::string& model_so_path, + size_t num_models = 1, + const bool run_single_threaded = false); + + ~AOTIModelContainerRunnerCpu() override; +}; + +} // namespace torch::inductor +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..83433951d5d7b6d0974da6951dd2623c5e9d53e0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h @@ -0,0 +1,35 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include + +namespace torch::inductor { + +// NOTICE: Following APIs are subject to change due to active development +// We provide NO BC guarantee for these APIs +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +class TORCH_CUDA_CPP_API AOTIModelContainerRunnerCuda + : public AOTIModelContainerRunner { + public: + // @param device_str: cuda device string, e.g. "cuda", "cuda:0" + AOTIModelContainerRunnerCuda( + const std::string& model_so_path, + size_t num_models = 1, + const std::string& device_str = "cuda", + const std::string& cubin_dir = "", + const bool run_single_threaded = false); + + ~AOTIModelContainerRunnerCuda() override; + + std::vector run_impl( + std::vector& input_handles, + void* stream_handle) override; + + std::vector run_with_cuda_stream( + const std::vector& inputs, + const at::cuda::CUDAStream& cuda_stream); +}; + +} // namespace torch::inductor +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_mps.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_mps.h new file mode 100644 index 0000000000000000000000000000000000000000..37cb3e5f8b1c4776bb084dc07d8f2a7a1835c1b4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_mps.h @@ -0,0 +1,18 @@ +#if defined(__APPLE__) +#pragma once + +#include + +namespace torch::inductor { +class TORCH_API AOTIModelContainerRunnerMps : public AOTIModelContainerRunner { + public: + AOTIModelContainerRunnerMps( + const std::string& model_so_path, + size_t num_models = 1, + const bool run_single_threaded = false); + + ~AOTIModelContainerRunnerMps() override; +}; + +} // namespace torch::inductor +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h new file mode 100644 index 0000000000000000000000000000000000000000..e93d98d49abd87f5e7acb1aa6d843b67b6d81779 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h @@ -0,0 +1,37 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include + +namespace torch::inductor { + +// NOTICE: Following APIs are subject to change due to active development +// We provide NO BC guarantee for these APIs + +// HERE we use C10_EXPORT because libtorch_python needs this Symbol be exported. +// And `TORCH_API and `TORCH_XPU_API`` do not export the symbol in Windows +// build. +class C10_EXPORT AOTIModelContainerRunnerXpu : public AOTIModelContainerRunner { + public: + // @param device_str: xpu device string, e.g. "xpu", "xpu:0" + AOTIModelContainerRunnerXpu( + const std::string& model_so_path, + size_t num_models = 1, + const std::string& device_str = "xpu", + const std::string& kernel_bin_dir = "", + const bool run_single_threaded = false); + + ~AOTIModelContainerRunnerXpu() override; + + std::vector run_impl( + std::vector& input_handles, + void* stream_handle) override; + + std::vector run_with_xpu_stream( + const std::vector& inputs, + const at::xpu::XPUStream& xpu_stream); +}; + +} // namespace torch::inductor +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/pybind.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..3797c7e1e6a2798dd238cae06ff3db6eb990495b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runner/pybind.h @@ -0,0 +1,7 @@ +#include + +namespace torch::inductor { + +void initAOTIRunnerBindings(PyObject* module); + +} // namespace torch::inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..6ca660093c67d99dbf0e7635f66989d549055ad3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h @@ -0,0 +1,242 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace torch::aot_inductor { + +using MiniIntArrayRef = MiniArrayRef; + +static_assert( + sizeof(MiniIntArrayRef) == sizeof(void*) + sizeof(size_t), + "changing the size of MiniArrayRef breaks ABI compatibility!"); + +inline bool is_contiguous_strides_for_shape( + int64_t ndim, + const int64_t* strides_ptr, + const int64_t* sizes_ptr) { + int64_t z = 1; + for (int64_t d = ndim - 1; d >= 0; d--) { + const auto& size_d = sizes_ptr[d]; + if (size_d != 1) { + if (strides_ptr[d] == z) { + z *= size_d; + } else { + return false; + } + } + } + return true; +} + +// Shim for AOTI generated code to pretend a raw array works like an +// AtenTensorHandle. +template +class ArrayRefTensor { + public: + ArrayRefTensor() = default; + + explicit ArrayRefTensor( + MiniArrayRef arr, + MiniArrayRef sizes, + MiniArrayRef strides, + int32_t device_type, + int32_t device_idx) + : arrayRef_(arr), + sizes_(sizes), + strides_(strides), + device_type_(device_type), + device_idx_(device_idx) { + assert(sizes.size() == strides.size()); + assert(is_contiguous_strides_for_shape( + sizes.size(), strides.data(), sizes.data())); + } + + AtenTensorHandle expensiveCopyToTensor() const { + AtenTensorHandle result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided( + sizes_.size(), + sizes_.data(), + strides_.data(), + aoti_torch_dtype>(), + device_type_, + device_idx_, + &result)); + void* dataPtr = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(result, &dataPtr)); + std::memcpy(dataPtr, data(), numel() * sizeof(T)); + return result; + } + + // We need to look the same as RAIIAtenTensorHandle, which returns + // an owning AtenTensorHandle from release(). So, we allocate one! + AtenTensorHandle release() { + return expensiveCopyToTensor(); + } + + AtenTensorHandle borrowAsTensor() const { + AtenTensorHandle result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2( + data(), + sizes_.size(), + sizes_.data(), + strides_.data(), + 0, + aoti_torch_dtype>(), + device_type_, + device_idx_, + &result, + aoti_torch_layout_strided(), + nullptr, + 0)); + return result; + } + + // We don't need to free any memory. + void reset() {} + + auto sizes() const { + return sizes_; + } + + auto strides() const { + return strides_; + } + + auto device_type() const { + return device_type_; + } + + auto device_idx() const { + return device_idx_; + } + + T* data() const { + return arrayRef_.data(); + } + + auto numel() const { + return arrayRef_.size(); + } + + void set_arrayref(MiniArrayRef new_arrayref) { + arrayRef_ = new_arrayref; + } + + private: + MiniArrayRef arrayRef_; + // We expect generated code to have statically available sizes & + // strides for us. + MiniArrayRef sizes_; + MiniArrayRef strides_; + int32_t device_type_ = 0; + int32_t device_idx_ = 0; + // We continue to zero-initialize this field in case we repurpose + // the space later; having predictable contents can only help. + int32_t unusedDoNotRemoveForABICompatibility_ = 0; +}; + +static_assert( + sizeof(ArrayRefTensor) == + 3 * sizeof(MiniIntArrayRef) + 3 * sizeof(int32_t) + + (alignof(ArrayRefTensor) > 4 ? sizeof(int32_t) : 0), + "changing the size of ArrayRefTensor breaks ABI compatibility!"); + +template +inline ArrayRefTensor reinterpret_tensor_wrapper( + const ArrayRefTensor& self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset) { + // REVIEW: we should add a way to build the DSO in debug mode during + // tests so we can have checks like this! + assert(is_contiguous_strides_for_shape(ndim, strides_ptr, sizes_ptr)); + return ArrayRefTensor( + MiniArrayRef( + self.data() + storage_offset, self.numel() - storage_offset), + MiniArrayRef(sizes_ptr, ndim), + MiniArrayRef(strides_ptr, ndim), + self.device_type(), + self.device_idx()); +} + +template +inline T* get_data_ptr_wrapper(ArrayRefTensor& tensor) { + return tensor.data(); +} + +template +inline T* get_data_ptr_wrapper(const MiniArrayRef& arr) { + return arr.data(); +} + +template +inline const ArrayRefTensor& unwrap_raii_handle_if_needed( + const ArrayRefTensor& tensor) { + return tensor; +} + +template +inline ArrayRefTensor& unwrap_raii_handle_if_needed( + ArrayRefTensor& tensor) { + return tensor; +} + +template +inline const ArrayRefTensor& wrap_with_raii_handle_if_needed( + const ArrayRefTensor& tensor) { + return tensor; +} + +template +inline ArrayRefTensor& wrap_with_raii_handle_if_needed( + ArrayRefTensor& tensor) { + return tensor; +} + +template +inline ArrayRefTensor wrap_with_raii_handle_if_needed( + ArrayRefTensor&& tensor) { + return std::move(tensor); +} + +template +inline RAIIAtenTensorHandle expensive_copy_to_tensor_if_needed( + const ArrayRefTensor& tensor) { + return tensor.expensiveCopyToTensor(); +} + +inline AtenTensorHandle expensive_copy_to_tensor_if_needed( + AtenTensorHandle handle) { + return handle; +} + +template +const T& copy_arrayref_tensor_to_tensor(const T& t) { + return t; +} + +template +RAIIAtenTensorHandle copy_arrayref_tensor_to_tensor( + const ArrayRefTensor& art) { + return art.expensiveCopyToTensor(); +} + +template +const T& borrow_arrayref_tensor_as_tensor(const T& t) { + return t; +} + +template +RAIIAtenTensorHandle borrow_arrayref_tensor_as_tensor( + const ArrayRefTensor& art) { + return art.borrowAsTensor(); +} + +} // namespace torch::aot_inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/constant_type.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/constant_type.h new file mode 100644 index 0000000000000000000000000000000000000000..053eed728fb05301db7cd42ae81d0949372d5e77 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/constant_type.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. + +namespace torch::aot_inductor { + +enum ConstantType : uint8_t { + Unknown = 0, + Parameter = 1, + Buffer = 2, + TensorConstant = 3, + FoldedConstant = 4, +}; + +} // namespace torch::aot_inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/device_utils.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/device_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..7b48f493af20f2d6fab66ed42f3d7d98345d272b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/device_utils.h @@ -0,0 +1,67 @@ +#pragma once + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. + +#ifdef USE_CUDA + +// FIXME: Currently, CPU and CUDA backend are mutually exclusive. +// This is a temporary workaround. We need a better way to support +// multi devices. + +#include +#include + +#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \ + do { \ + const cudaError_t code = EXPR; \ + const char* msg = cudaGetErrorString(code); \ + if (code != cudaSuccess) { \ + throw std::runtime_error( \ + std::string("CUDA error: ") + std::string(msg)); \ + } \ + } while (0) + +namespace torch::aot_inductor { + +using DeviceStreamType = cudaStream_t; + +} // namespace torch::aot_inductor + +#elif defined(USE_XPU) +#include +#include +#include +#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \ + do { \ + const ze_result_t status = EXPR; \ + if (status != ZE_RESULT_SUCCESS) { \ + std::stringstream ss; \ + ss << "L0 runtime error: " << std::hex << std::uppercase << status; \ + throw std::runtime_error(ss.str()); \ + } \ + } while (0) + +namespace torch::aot_inductor { + +using DeviceStreamType = sycl::queue*; + +} // namespace torch::aot_inductor + +#else + +#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \ + bool ok = EXPR; \ + if (!ok) { \ + throw std::runtime_error("CPU runtime error"); \ + } + +namespace torch::aot_inductor { + +using DeviceStreamType = void*; + +} // namespace torch::aot_inductor + +#endif // USE_CUDA diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/interface.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/interface.h new file mode 100644 index 0000000000000000000000000000000000000000..f2b29049bb8112e15c7bd3c03df345f743245785 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/interface.h @@ -0,0 +1,230 @@ +#pragma once + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. +#include + +extern "C" { +struct AOTInductorModelOpaque; +using AOTInductorModelHandle = AOTInductorModelOpaque*; + +struct AOTInductorModelContainerOpaque; +using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*; + +struct AOTInductorStreamOpaque; +using AOTInductorStreamHandle = AOTInductorStreamOpaque*; + +struct AOTInductorConstantMap; +using AOTInductorConstantMapHandle = AOTInductorConstantMap*; + +// TODO: Deprecate this API. This was kept for BC compatibility. +// Please use AOTInductorModelContainerCreateWithDevice instead. +AOTIRuntimeError AOTInductorModelContainerCreate( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + bool is_cpu, + const char* cubin_dir); + +// Creates an AOTInductor model container. The parameter num_models +// specifies the number of model instances that may be run concurrently for +// the same input model. +// `device_str` MUST NOT be nullptr. It must be a valid device string, e.g. +// "cpu", "cuda", "cuda:0", etc. If the device index is not specified for CUDA +// device, runtime will use the device index returned by +// "cudaGetDevice(&device_idx)" +AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + const char* device_str, + const char* cubin_dir); + +// Deletes the AOTInductor model container. +AOTIRuntimeError AOTInductorModelContainerDelete( + AOTInductorModelContainerHandle container_handle); + +// Runs the inference. +AOTIRuntimeError AOTInductorModelContainerRun( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle); + +// Single-threaded variant of previous. +AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle); + +// Retrieves the number of constants for the model. +AOTIRuntimeError AOTInductorModelContainerGetNumConstants( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants); + +// Retrieves a constant's name. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantName( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** name); + +// Retrieves a constant's original FQN. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** original_fqn); + +// Retrieves whether a constant is from folded. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( + AOTInductorModelContainerHandle container_handle, + size_t idx, + bool* from_folded); + +// Retrieves the inductor constant type. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantType( + AOTInductorModelContainerHandle container_handle, + size_t idx, + int32_t* type); + +// Retrieves a constant's dtype. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( + AOTInductorModelContainerHandle container_handle, + size_t idx, + int32_t* dtype); + +// Retrieves a constant's data size. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize( + AOTInductorModelContainerHandle container_handle, + size_t idx, + size_t* data_size); + +// Extract the constants that is being used in the container. +AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive); + +// Setup the constant buffer in model container with provided ConstantMap. +// The ConstantMap is user managed, and the user would retain ownership. +AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, + bool validate_full_update); + +// Setup the constant buffer in model container with provided ConstantMap +// use_inactive should be set as true if the inactive buffer is to be updated. +// validate_full_update checks if all constants are included in the ConstantMap +AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, + bool validate_full_update); + +// Setup the inactive constant buffer in model container with provided +// ConstantMap +AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle); + +// Free the inactive constant buffer in model container. +AOTIRuntimeError AOTInductorModelContainerFreeInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle); + +// Run constant folding on constant buffer. +AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( + AOTInductorModelContainerHandle container_handle, + bool use_inactive, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle); + +// Swap the constant buffer being used to the inactive one. +AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( + AOTInductorModelContainerHandle container_handle); + +// Retrieves the number of inputs for the model. +AOTIRuntimeError AOTInductorModelContainerGetNumInputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_inputs); + +// Retrieves the input name at the given index. +AOTIRuntimeError AOTInductorModelContainerGetInputName( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** ret_input_names); + +// Retrieves the number of outputs for the model. +AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_outputs); + +// Retrieves the output name at the given index. +AOTIRuntimeError AOTInductorModelContainerGetOutputName( + AOTInductorModelContainerHandle container_handle, + size_t output_idx, + const char** ret_output_names); + +// Creates an AOTInductorModel instance. This is a thin and light wrapper +// around the compiled model; it doesn't handle concurrency, queueing, device +// management, etc. Use this if bare-metal performance is needed and you are +// willing to handle other "management" aspects yourself. +// +// constant_map_handle is an opaque type to satisfy the C ABI. It should be a +// std::unordered_map*. +AOTIRuntimeError AOTInductorModelCreate( + AOTInductorModelHandle* model_handle, + AOTInductorConstantMapHandle constant_map_handle); + +// Run an AOTInductorModel (see AOTInductorModelCreate for when one should use +// this function versus AOTInductorModelContainerRun). +AOTIRuntimeError AOTInductorModelRun( + AOTInductorModelHandle model_handle, + AtenTensorHandle* input_handles, + AtenTensorHandle* output_handles); + +// Replace AOTInductorModel's constant map. Note it doesn't handle concurrency +// so be sure to handle ordering if AOTInductorModelRun is ran concurrently. +AOTIRuntimeError AOTInductorModelUpdateConstantsMap( + AOTInductorModelHandle model_handle, + AOTInductorConstantMapHandle constant_map_handle); + +// Delete an AOTInductorModel created by AOTInductorModelCreate. +AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle); + +AOTIRuntimeError AOTInductorModelGetNumOutputs( + AOTInductorModelHandle model_handle, + size_t* ret_num_outputs); + +AOTIRuntimeError AOTInductorModelContainerGetCallSpec( + AOTInductorModelContainerHandle container_handle, + const char** in_spec, + const char** out_spec); + +} // extern "C" diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/mini_array_ref.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/mini_array_ref.h new file mode 100644 index 0000000000000000000000000000000000000000..84a7dddb77adeda57254f7e8820eb4c01e9ffe27 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/mini_array_ref.h @@ -0,0 +1,160 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch::aot_inductor { + +// Can't use c10::ArrayRef because it's not truly header-only and +// pulls in other c10 headers. This is (sadly) copy-pasted and +// adapted. +template +class MiniArrayRef final { + public: + using iterator = T*; + using const_iterator = const T*; + using size_type = size_t; + using value_type = T; + + using reverse_iterator = std::reverse_iterator; + + private: + /// The start of the array, in an external buffer. + T* Data; + + /// The number of elements. + size_type Length; + + public: + /// @name Constructors + /// @{ + + /// Construct an empty MiniArrayRef. + /* implicit */ constexpr MiniArrayRef() : Data(nullptr), Length(0) {} + + /// Construct an MiniArrayRef from a single element. + // TODO Make this explicit + constexpr MiniArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} + + /// Construct an MiniArrayRef from a pointer and length. + constexpr MiniArrayRef(T* data, size_t length) : Data(data), Length(length) {} + + /// Construct an MiniArrayRef from a range. + constexpr MiniArrayRef(T* begin, T* end) : Data(begin), Length(end - begin) {} + + template < + typename Container, + typename = std::enable_if_t().data())>, + T*>>> + /* implicit */ MiniArrayRef(Container& container) + : Data(container.data()), Length(container.size()) {} + + /// Construct an MiniArrayRef from a std::vector. + // The enable_if stuff here makes sure that this isn't used for + // std::vector, because MiniArrayRef can't work on a std::vector + // bitfield. + template + /* implicit */ MiniArrayRef(const std::vector& Vec) + : Data(Vec.data()), Length(Vec.size()) { + static_assert( + !std::is_same_v, + "MiniArrayRef cannot be constructed from a std::vector bitfield."); + } + + /// Construct an MiniArrayRef from a std::array + template + /* implicit */ constexpr MiniArrayRef(std::array& Arr) + : Data(Arr.data()), Length(N) {} + + /// Construct an MiniArrayRef from a C array. + template + // NOLINTNEXTLINE(*c-array*) + /* implicit */ constexpr MiniArrayRef(T (&Arr)[N]) : Data(Arr), Length(N) {} + + // /// Construct an MiniArrayRef from an empty C array. + /* implicit */ constexpr MiniArrayRef(const volatile void* Arr) + : Data(nullptr), Length(0) {} + + /// Construct an MiniArrayRef from a std::initializer_list. + /* implicit */ constexpr MiniArrayRef(const std::initializer_list& Vec) + : Data( + std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) + : std::begin(Vec)), + Length(Vec.size()) {} + + /// @} + /// @name Simple Operations + /// @{ + + constexpr iterator begin() const { + return Data; + } + constexpr iterator end() const { + return Data + Length; + } + + // These are actually the same as iterator, since MiniArrayRef only + // gives you const iterators. + constexpr const_iterator cbegin() const { + return Data; + } + constexpr const_iterator cend() const { + return Data + Length; + } + + constexpr reverse_iterator rbegin() const { + return reverse_iterator(end()); + } + constexpr reverse_iterator rend() const { + return reverse_iterator(begin()); + } + + /// empty - Check if the array is empty. + constexpr bool empty() const { + return Length == 0; + } + + constexpr T* data() const { + return Data; + } + + /// size - Get the array size. + constexpr size_t size() const { + return Length; + } + + /// equals - Check for element-wise equality. + constexpr bool equals(MiniArrayRef RHS) const { + return Length == RHS.Length && std::equal(begin(), end(), RHS.begin()); + } + + /// @} + /// @name Operator Overloads + /// @{ + constexpr const T& operator[](size_t Index) const { + return Data[Index]; + } + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, MiniArrayRef>& operator=( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + U&& Temporary) = delete; + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, MiniArrayRef>& operator=( + std::initializer_list) = delete; +}; + +} // namespace torch::aot_inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/model.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/model.h new file mode 100644 index 0000000000000000000000000000000000000000..1c12f018cd423d99c8388b6717ab3d4e46e05aa8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/model.h @@ -0,0 +1,781 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. +#include +#ifdef USE_MPS +#include +#endif // USE_MPS +#ifdef USE_XPU +#include +#else +#include +#endif // USE_XPU +#include + +#define AOTI_RUNTIME_CHECK(EXPR, MSG) \ + do { \ + bool ok = EXPR; \ + if (!ok) { \ + throw std::runtime_error(MSG); \ + } \ + } while (0) + +// At codegen time, we write out a binary file called constants.bin. +// We then turn the raw binary to an object file that exposes this +// symbol and link it into the final .so. +// For information on the binary format, see `man objcopy`, under +// the "binary-architecture" flag: +// https://man7.org/linux/man-pages/man1/objcopy.1.html +// todo: use #embed in C++ 23 once available +// The constants are NOT readonly because they may be mutated. +// NOLINTNEXTLINE(*array*) +extern uint8_t _binary_constants_bin_start[]; +// NOLINTNEXTLINE(*array*) +extern uint8_t _binary_constants_bin_end[]; + +#if defined(USE_CUDA) || defined(USE_XPU) +// Compute required blob size with 64-alignment if on GPU. +#define AOTI_CONST_ALIGNMENT 64 +#else +// Use 64-alignment (use something >=64)for better performance on CPU. +#define AOTI_CONST_ALIGNMENT 64 +#endif + +namespace { + +using RAIIDataPtr = std::unique_ptr>; + +#ifdef USE_CUDA + +RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) { + void* data_ptr; + AOTI_RUNTIME_DEVICE_CHECK(cudaMalloc((void**)&data_ptr, num_bytes)); + auto deleter = [](void* ptr) { AOTI_RUNTIME_DEVICE_CHECK(cudaFree(ptr)); }; + return RAIIDataPtr(data_ptr, deleter); +} + +#elif defined(USE_XPU) + +RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) { + sycl::queue* queue_ptr = nullptr; + aoti_torch_get_current_sycl_queue((void**)&queue_ptr); + void* data_ptr = sycl::malloc_device(num_bytes, *queue_ptr); + auto deleter = [queue_ptr](void* ptr) { sycl::free(ptr, *queue_ptr); }; + return RAIIDataPtr(data_ptr, deleter); +} + +#elif defined(USE_MPS) + +RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) { + void* data_ptr = nullptr; + aoti_torch_mps_malloc(&data_ptr, num_bytes); + auto deleter = [](void* ptr) { aoti_torch_mps_free(ptr); }; + return RAIIDataPtr(data_ptr, deleter); +} + +#else + +RAIIDataPtr RAII_cpuMalloc(size_t num_bytes) { + void* data_ptr = std::malloc(num_bytes); + if (!data_ptr) { + throw std::bad_alloc(); + } + auto deleter = [](void* ptr) { std::free(ptr); }; + return RAIIDataPtr(data_ptr, deleter); +} + +#endif // USE_CUDA + +} // anonymous namespace + +namespace torch::aot_inductor { + +using ConstantMap = + std::unordered_map; + +// valid device strs are: cpu, cuda, cuda:0, cuda:1, ... +// Update the list here if more devices are supported in the future +inline void parse_device_str( + const std::string& device_str, + int32_t& device_type, + int32_t& device_idx) { + std::regex re("(cpu|cuda|xpu|mps)(:([0-9]+))?"); + std::smatch sm; + bool matched = std::regex_match(device_str, sm, re); + AOTI_RUNTIME_CHECK(matched, "Invalid device: " + device_str); + + if (sm[1].str() == "cpu") { + device_type = aoti_torch_device_type_cpu(); + } else if (sm[1].str() == "cuda") { + device_type = aoti_torch_device_type_cuda(); +#ifdef USE_XPU + } else if (sm[1].str() == "xpu") { + device_type = aoti_torch_device_type_xpu(); +#endif +#ifdef USE_MPS + } else if (sm[1].str() == "mps") { + device_type = aoti_torch_device_type_mps(); +#endif + } else { + AOTI_RUNTIME_CHECK(false, "Invalid device: " + device_str); + } + + if (sm[3].matched) { + device_idx = stoi(sm[3].str()); + } else { + device_idx = -1; + } +} + +// Defines the base class for AOTInductorModel, which is generated by the +// AOTInductor cpp codegen. Since we do not need dynamic dispatch, we rely +// on curiously recurring template pattern (CRTP) to save some runtime +// v-table overhead. The generated AOTInductorModel is specialized with +// methods such as run_impl. +template +class AOTInductorModelBase { + public: + AOTInductorModelBase( + size_t num_inputs, + size_t num_outputs, + size_t num_constants, + const std::string& device_str, + std::optional cubin_dir, + bool include_weights = true) + : inputs_info_(num_inputs), + outputs_info_(num_outputs), + constants_info_(num_constants), + cubin_dir_(std::move(cubin_dir)), + include_weights(include_weights) { + parse_device_str(device_str, device_type_, device_idx_); + +#ifdef USE_CUDA + if (device_idx_ == -1) { + AOTI_RUNTIME_DEVICE_CHECK(cudaGetDevice(&device_idx_)); + } else { + // If device_idx_ is passed in, we need to set the current device to it + AOTI_RUNTIME_DEVICE_CHECK(cudaSetDevice(device_idx_)); + } +#endif // USE_CUDA +#ifdef USE_XPU + if (device_idx_ == -1) { + aoti_torch_get_current_xpu_device(&device_idx_); + } else { + aoti_torch_set_current_xpu_device(device_idx_); + } +#endif // USE_XPU +#ifdef USE_MPS + if (device_idx_ == -1) { + device_idx_ = 0; + } +#endif // USE_MPS + } + + // NOLINTNEXTLINE(modernize-use-equals-default) + ~AOTInductorModelBase() { +#ifdef USE_CUDA + if (run_finished_) { + auto code = cudaEventDestroy(*run_finished_); + if (code != cudaSuccess) { + std::cerr << "Failed to destroy CUDA event in AOTInductor model: " + << cudaGetErrorString(code) << std::endl; + } + } +#endif // USE_CUDA +#ifdef USE_XPU + if (run_finished_) { + (*run_finished_)->wait_and_throw(); + delete *run_finished_; + } +#endif // USE_XPU + } + + AOTInductorModelBase(AOTInductorModelBase&&) = delete; + AOTInductorModelBase& operator=(AOTInductorModelBase&&) = delete; + AOTInductorModelBase(const AOTInductorModelBase&) = delete; + AOTInductorModelBase& operator=(const AOTInductorModelBase&) = delete; + + void run( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor) { +#ifdef USE_CUDA + if (!run_finished_) { + cudaEvent_t run_finished; + AOTI_RUNTIME_DEVICE_CHECK(cudaEventCreate(&run_finished)); + run_finished_.emplace(run_finished); + } +#elif defined(USE_XPU) + if (run_finished_) { + (*run_finished_)->wait_and_throw(); + delete *run_finished_; + run_finished_.reset(); + } +#else // !USE_CUDA && !USE_XPU + run_finished_ = false; +#endif + + auto* model = static_cast(this); + model->run_impl(input_handles, output_handles, stream, proxy_executor); + +#ifdef USE_CUDA + AOTI_RUNTIME_DEVICE_CHECK(cudaEventRecord(*run_finished_, stream)); +#elif defined(USE_XPU) + run_finished_ = std::make_optional(new sycl::event( + static_cast(stream)->ext_oneapi_submit_barrier())); +#else // !USE_CUDA && !USE_XPU + run_finished_ = true; +#endif // USE_CUDA + } + + // Non-thread-aware variant of run(). Obviously unsafe to use in a threaded + // environment :) + void run_single_threaded( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor) { + // don't bother with any of the run_finished stuff; this is unsafe to call + // in a threaded context + auto* model = static_cast(this); + model->run_impl(input_handles, output_handles, stream, proxy_executor); + } + + std::unordered_map run_const_fold( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization = false) { +#ifdef USE_CUDA + if (!run_finished_) { + cudaEvent_t run_finished; + AOTI_RUNTIME_DEVICE_CHECK(cudaEventCreate(&run_finished)); + run_finished_.emplace(run_finished); + } +#elif defined(USE_XPU) + if (run_finished_) { + (*run_finished_)->wait_and_throw(); + delete *run_finished_; + run_finished_.reset(); + } +#else // !USE_CUDA && !USE_XPU + run_finished_ = false; +#endif + + auto* model = static_cast(this); + auto folded_constants = + model->const_run_impl(stream, proxy_executor, initialization); + +#ifdef USE_CUDA + AOTI_RUNTIME_DEVICE_CHECK(cudaEventRecord(*run_finished_, stream)); +#elif defined(USE_XPU) + // sycl::queue* queue_ptr = nullptr; + // aoti_torch_get_current_sycl_queue((void**)&queue_ptr); + run_finished_ = std::make_optional(new sycl::event( + static_cast(stream)->ext_oneapi_submit_barrier())); + +#else // !USE_CUDA && !USE_XPU + run_finished_ = true; +#endif // USE_CUDA + + return folded_constants; + } + + void load_constants() { + size_t num_constants = this->num_constants(); + size_t num_folded_constants = this->num_folded_constants(); + constants_map_->reserve(num_constants); + + std::vector constants_internal_offset( + num_constants - num_folded_constants); + size_t blob_size = 0; + compute_constant_blob(blob_size, constants_internal_offset); + if (!include_weights) { + return; + } +#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS) + constant_blob_ = RAII_gpuMalloc(blob_size); +#else + constant_blob_ = RAII_cpuMalloc(blob_size); +#endif + + size_t bytes_read = 0; + for (size_t i = 0; i < num_constants; i++) { + bool from_folded = this->constant_from_folded(i); + if (from_folded) { + continue; + } + std::string name = this->constant_name(i); + size_t data_size = this->constant_data_size(i); + uint8_t* internal_ptr = (data_size != 0) + ? constant_ptr( + constants_internal_offset[i], + bytes_read, + data_size, + /* skip_copy = */ false) + : nullptr; + bytes_read += data_size; + + // Create at::Tensor from copied memory. + auto dtype = this->constant_dtype(i); + auto ndim = this->constant_ndim(i); + auto size = this->constant_shape(i); + auto stride = this->constant_stride(i); +#ifdef USE_MPS + auto offset = this->constant_offset(i) + + (constants_internal_offset[i] / aoti_torch_dtype_element_size(dtype)); +#else + auto offset = this->constant_offset(i); +#endif + auto layout = this->constant_layout(i); + auto opaque_metadata_ptr = this->opaque_metadata(i); + auto opaque_metadata_size = this->opaque_metadata_size(i); + + AtenTensorHandle tensor_handle = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2( + internal_ptr, + ndim, + size, + stride, + offset, + dtype, + device_type_, + device_idx_, + &tensor_handle, + layout, + opaque_metadata_ptr, + opaque_metadata_size)); + constants_map_->emplace(std::move(name), tensor_handle); + } + if (constants_map_) { + this->update_constants_array_from_map(); + } + } + + RAIIDataPtr&& release_constant_blob() { + return std::move(constant_blob_); + } + + std::shared_ptr> get_constants_array() { + return constants_; + } + + int32_t get_device_type() const { + return device_type_; + } + + int32_t get_device_idx() const { + return device_idx_; + } + + uint8_t* constant_ptr( + size_t constant_offset, + size_t bytes_read, + size_t data_size, + bool skip_copy) { + auto* constants_ptr = static_cast(constant_blob_.get()); + uint8_t* internal_ptr = constants_ptr + constant_offset; + // TODO: Handle shared storage case. + if (!skip_copy) { +#ifdef USE_XPU + sycl::queue* queue_ptr = nullptr; + aoti_torch_get_current_sycl_queue((void**)&queue_ptr); + queue_ptr + ->memcpy(internal_ptr, _get_constants_start() + bytes_read, data_size) + .wait(); +#elif USE_CUDA + AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy( + internal_ptr, + _get_constants_start() + bytes_read, + data_size, + cudaMemcpyHostToDevice)); +#elif USE_MPS + aoti_torch_mps_memcpy( + constants_ptr, + constant_offset, + bytes_read, + data_size, + _get_constants_start()); + return constants_ptr; +#else + memcpy(internal_ptr, _get_constants_start() + bytes_read, data_size); +#endif + } + return internal_ptr; + } + + void compute_constant_blob( + size_t& blob_size, + std::vector& constants_internal_offset) { + size_t num_constants = this->num_constants(); + blob_size = 0; + size_t curr_idx = 0; + for (size_t i = 0; i < num_constants; i++) { + if (this->constant_from_folded(i)) { + continue; + } + size_t data_size = this->constant_data_size(i); + if (data_size % AOTI_CONST_ALIGNMENT) { + data_size = AOTI_CONST_ALIGNMENT + + (data_size / AOTI_CONST_ALIGNMENT) * AOTI_CONST_ALIGNMENT; + } + constants_internal_offset[curr_idx++] = blob_size; + blob_size += data_size; + } + } + + size_t num_inputs() const { + return inputs_info_.size(); + } + + size_t num_outputs() const { + return outputs_info_.size(); + } + + size_t num_constants() const { + return constants_info_.size(); + } + + size_t num_folded_constants() const { + size_t total_consts = this->num_constants(); + size_t folded_consts = 0; + for (size_t i = 0; i < total_consts; i++) { + if (this->constant_from_folded(i)) { + folded_consts++; + } + } + return folded_consts; + } + + const char* input_name(int64_t idx) const { + return inputs_info_.at(idx).name; + } + + const char* output_name(int64_t idx) const { + return outputs_info_.at(idx).name; + } + + const char* constant_name(int64_t idx) const { + return constants_info_.at(idx).name; + } + + size_t constant_ndim(int64_t idx) { + return constants_info_.at(idx).shape.size(); + } + + const int64_t* constant_shape(int64_t idx) const { + return constants_info_.at(idx).shape.data(); + } + + const int64_t* constant_stride(int64_t idx) const { + return constants_info_.at(idx).stride.data(); + } + + int32_t constant_dtype(int64_t idx) const { + return constants_info_.at(idx).dtype; + } + + int32_t constant_layout(int64_t idx) const { + return constants_info_.at(idx).layout; + } + + size_t constant_offset(int64_t idx) const { + return constants_info_.at(idx).offset; + } + + size_t constant_data_size(int64_t idx) const { + return constants_info_.at(idx).data_size; + } + + const char* constant_original_fqn(int64_t idx) const { + return constants_info_.at(idx).original_fqn; + } + + const uint8_t* opaque_metadata(int64_t idx) const { + return constants_info_.at(idx).opaque_metadata.data(); + } + + size_t opaque_metadata_size(int64_t idx) { + return constants_info_.at(idx).opaque_metadata.size(); + } + + bool constant_from_folded(int64_t idx) const { + return constants_info_.at(idx).from_folded; + } + + int32_t constant_type(int64_t idx) const { + return constants_info_.at(idx).type; + } + + const char* get_in_spec() const { + return in_spec_.c_str(); + } + + const char* get_out_spec() const { + return out_spec_.c_str(); + } + + void update_constants_array_from_map() { + if (!constants_map_) { + throw std::runtime_error{ + "constants_map_ was not ready when constants_ is trying to be constructed from it!"}; + } + if (!constants_) { + constants_ = + std::make_shared>(constants_info_.size()); + } else { + constants_->resize(constants_info_.size()); + } + int idx = 0; + for (const auto& info : constants_info_) { + const auto it = constants_map_->find(info.name); + if (it != constants_map_->end()) { + constants_->at(idx) = ConstantHandle(it->second); + } + idx++; + } + } + + void update_constants_map( + std::shared_ptr constants_map, + bool remap_constants_array = true) { + constants_map_ = std::move(constants_map); + if (remap_constants_array) { + update_constants_array_from_map(); + } + } + + // This function allows us to update the constants_ that is used to look up + // the corresponding constant tensor during runtime. + void update_constants_array( + std::shared_ptr> constants_array) { + constants_ = std::move(constants_array); + } + + /// Returns true if the model is complete. + bool is_finished() { +#ifdef USE_CUDA + if (!run_finished_) { + throw std::runtime_error{"Model CUDA event was not initialized"}; + } + + auto event_status = cudaEventQuery(*run_finished_); + if (event_status == cudaSuccess) { + return true; + } else if (event_status == cudaErrorNotReady) { + return false; + } + + throw std::runtime_error( + std::string("The model did not finish successfully. Error: ") + + cudaGetErrorString(cudaGetLastError())); +#elif defined(USE_XPU) + if (!run_finished_) { + throw std::runtime_error{"Model XPU event was not initialized"}; + } + using namespace sycl::info; + return (*run_finished_)->get_info() == + event_command_status::complete; + +#else // !USE_CUDA && !USE_XPU + return run_finished_; +#endif // USE_CUDA + } + + /// Synchronizes completion event. + void wait_for_completion() { +#ifdef USE_CUDA + if (!run_finished_) { + throw std::runtime_error{"Model event was not initialized"}; + } + + AOTI_RUNTIME_DEVICE_CHECK(cudaEventSynchronize(*run_finished_)); +#endif // USE_CUDA +#ifdef USE_XPU + if (!run_finished_) { + throw std::runtime_error{"Model event was not initialized"}; + } + (*run_finished_)->wait_and_throw(); +#endif + } + + protected: + uint8_t* _get_constants_start() { +#ifndef USE_MMAP_SELF + // NOLINTNEXTLINE(*const-cast*) + return const_cast(_binary_constants_bin_start); +#else + if (self_mmap) { + return self_mmap; + } + Dl_info dl_info; + // get pointer to constant which are appended to the binary + AOTI_RUNTIME_CHECK( + dladdr(__func__, &dl_info), "Can't find shared library name"); + int fd = open(dl_info.dli_fname, O_RDONLY); + AOTI_RUNTIME_CHECK(fd >= 0, "Shared library file cannot be opened"); + auto fsize = lseek(fd, 0, SEEK_END); + auto weights_size = + reinterpret_cast(_binary_constants_bin_start)[0]; + auto magic_number = + reinterpret_cast(_binary_constants_bin_start)[1]; + auto weights_offset = fsize - weights_size; + AOTI_RUNTIME_CHECK( + (weights_offset & 0x3fff) == 0, + "weights_offset must be aligned to 16K boundary"); + auto ptr = mmap( + NULL, + weights_size, + PROT_READ | PROT_WRITE, + MAP_PRIVATE, + fd, + weights_offset); + close(fd); + AOTI_RUNTIME_CHECK(ptr != MAP_FAILED, "mmap() failed"); + self_mmap = static_cast(ptr); + AOTI_RUNTIME_CHECK( + reinterpret_cast( + self_mmap + weights_size - sizeof(uint64_t))[0] == magic_number, + "Weights data seems corrupt"); + return self_mmap; +#endif + } + struct ParamInfo { + const char* name = nullptr; + }; + + struct ConstInfo { + const char* name = nullptr; + std::vector shape; + std::vector stride; + int32_t dtype{}; + int64_t offset{}; + size_t data_size{}; + int32_t layout{}; + std::vector opaque_metadata; + int64_t opaque_metadata_size{}; + const char* original_fqn = nullptr; + bool from_folded{}; + int32_t type{}; + }; + + std::vector inputs_info_; + std::vector outputs_info_; + std::vector constants_info_; + std::string in_spec_; + std::string out_spec_; + + std::shared_ptr constants_map_; + std::shared_ptr> constants_; + + // Holds the blob storage for constants' at::Tensor. + RAIIDataPtr constant_blob_; + +#ifdef USE_MMAP_SELF + uint8_t* self_mmap = NULL; +#endif + + // A directory with CUDA binary files, e.g. compiled kernels, etc. + const std::optional cubin_dir_; + + // This is the flag that implies whether the weight is included in the model. + // If True, we would prepare the weight when loading the model, otherwise the + // model will be loaded without weights, and need to be provided by the user. + bool include_weights; + + // Record if the model finishes an inference run so that its owning + // AOTModelContainer can reuse this instance. +#ifdef USE_CUDA + std::optional run_finished_; +#elif defined(USE_XPU) + std::optional run_finished_; +#else // !USE_CUDA + bool run_finished_{}; +#endif + + // Generated model uses this device index to create CUDA guards. + int32_t device_type_{}; + int32_t device_idx_{}; +}; + +// Codegen-ed classes can derive from this to keep pointers to loaded kernels. +class AOTInductorModelKernelsBase { + public: + virtual ~AOTInductorModelKernelsBase() = default; +}; + +class AOTInductorModel : public AOTInductorModelBase { + public: + AOTInductorModel( + std::shared_ptr constants_map, + std::shared_ptr> constants_array, + const std::string& device_str, + std::optional cubin_dir); + + std::unordered_map const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization = false); + + void _const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor); + + void run_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor); + + template + Outputs run_impl_minimal_arrayref_interface( + const Inputs& inputs, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor); + + static std::unique_ptr Create( + std::shared_ptr constants_map, + std::shared_ptr> constants_array, + const std::string& device_str, + std::optional cubin_dir) { + return std::make_unique( + std::move(constants_map), + std::move(constants_array), + device_str, + std::move(cubin_dir)); + } + + private: + std::unique_ptr kernels_; +}; + +} // namespace torch::aot_inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/model_container.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/model_container.h new file mode 100644 index 0000000000000000000000000000000000000000..10292f7968a268eef864db4883370290775b592c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/model_container.h @@ -0,0 +1,762 @@ +#pragma once + +#include +#include +#include +#include +#include + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. +#include + +namespace torch::aot_inductor { +// The state transition is done by: +// (1) NONE state: The default state when created. This state should only exist +// when model_container is created and no constants are being loaded or updated. +// (2) INITIALIZED state: This state get set whenever we load the constants into +// the buffer. This could be done by load_constants or update_constants_buffer. +// (3) FOLDED state: This state should transition from INITIALIZED after +// const_fold is being invoked. +enum class ConstantState : uint8_t { NONE, INITIALIZED, FOLDED, UNKNOWN }; + +inline std::string toStringConstantState(ConstantState state) { + switch (state) { + case ConstantState::NONE: + return "ConstantState::NONE"; + case ConstantState::INITIALIZED: + return "ConstantState::INITIALIZED"; + case ConstantState::FOLDED: + return "ConstantState::FOLDED"; + case ConstantState::UNKNOWN: + return "ConstantState::UNKNOWN"; + default: + return "Unknown enum class state for ConstantState"; + } +} + +class AOTInductorModelContainer { + public: + AOTInductorModelContainer( + size_t num_models, + const std::string& device_str, + const std::optional& cubin_dir = std::nullopt) { + constants_map_ = std::make_shared(); + constants_array_ = std::make_shared>(); + + models_.reserve(num_models); + available_models_.reserve(num_models); + for (size_t i = 0; i < num_models; ++i) { + models_.push_back(AOTInductorModel::Create( + constants_map_, constants_array_, device_str, cubin_dir)); + available_models_.push_back(models_.back().get()); + } + + // Note that the all following fields (input_names_, output_names, + // etc) can be filled in by the AOT + // codegen. However, we choose to query such information from + // the owned AOTInductorModel for a couple of reasons: + // * simplify the codegen templates + // * reduce information fragmentation and duplication + // * the initialization process below is done only once when the container + // is constructed, so it would have little performance impact + auto* model = available_models_[0]; + size_t num_inputs = model->num_inputs(); + input_names_.reserve(num_inputs); + for (size_t i = 0; i < num_inputs; i++) { + input_names_.emplace_back(model->input_name(static_cast(i))); + } + + size_t num_outputs = model->num_outputs(); + output_names_.reserve(num_outputs); + for (size_t i = 0; i < num_outputs; i++) { + output_names_.emplace_back(model->output_name(static_cast(i))); + } + model->load_constants(); + constant_blob_ = model->release_constant_blob(); + constants_internal_offset_.resize( + model->num_constants() - model->num_folded_constants()); + model->compute_constant_blob(blob_size_, constants_internal_offset_); + constant_folded_ = ConstantState::INITIALIZED; + + for (auto& model : models_) { + model->update_constants_map(constants_map_); + } + + in_spec_ = model->get_in_spec(); + out_spec_ = model->get_out_spec(); + } + + void run( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor) { + std::shared_lock model_lk(model_exec_mutex_); + auto* model = get_available_model(); + + ConstantState& const_folded = + use_secondary_ ? constant_folded_secondary_ : constant_folded_; + if (const_folded == ConstantState::INITIALIZED) { + // At this point, constant is not ready yet. We need to call constant + // folding before we execute the model. We obtain a unique lock at this + // point to make sure constant is ready for all. + model_lk.unlock(); + std::unique_lock constants_folding_lk(model_exec_mutex_); + // Double locking to make sure constant folding is only ran once. + if (const_folded == ConstantState::INITIALIZED) { + auto folded_const_map = model->run_const_fold( + stream, proxy_executor, /* initialization = */ true); + update_constant_buffer( + std::move(folded_const_map), + /* use_inactive = */ false, + /* validate_full_update = */ false); + const_folded = ConstantState::FOLDED; + } + constants_folding_lk.unlock(); + model_lk.lock(); + } else if (const_folded != ConstantState::FOLDED) { + throw std::runtime_error( + "Unknown constant state: " + toStringConstantState(constant_folded_)); + } + + try { + model->run(input_handles, output_handles, stream, proxy_executor); + } catch (...) { + std::lock_guard lk(models_mutex_); + available_models_.push_back(model); + throw; + } + + { + std::lock_guard lk(models_mutex_); + pending_models_.push_back(model); + } + pending_models_available_.notify_one(); + } + + // Non-thread-aware variant of run(). Obviously unsafe to use in a threaded + // environment :) + void run_single_threaded( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor) { + auto* model = available_models_[0]; + + ConstantState& const_folded = + use_secondary_ ? constant_folded_secondary_ : constant_folded_; + if (const_folded == ConstantState::INITIALIZED) { + auto folded_const_map = model->run_const_fold( + stream, proxy_executor, /* initialization = */ true); + update_constant_buffer( + std::move(folded_const_map), + /* use_inactive = */ false, + /* validate_full_update = */ false); + const_folded = ConstantState::FOLDED; + } else if (constant_folded_ != ConstantState::FOLDED) { + throw std::runtime_error( + "Unknown constant state: " + toStringConstantState(constant_folded_)); + } + + model->run_single_threaded( + input_handles, output_handles, stream, proxy_executor); + } + + const std::unordered_map extract_constants_map( + bool use_inactive) const { + size_t n_consts = this->num_constants(); + std::unordered_map ret; + ret.reserve(n_consts); + + std::shared_ptr extract_map = constants_map_; + // Essentially a XOR + if (use_inactive != use_secondary_) { + extract_map = constants_map_secondary_; + } + for (size_t idx = 0; idx < n_consts; idx++) { + if (this->constant_from_folded(idx)) { + continue; + } + + auto it = extract_map->find(this->constant_name(idx)); + if (it != extract_map->end()) { + ret.emplace(this->constant_original_fqn(idx), it->second); + continue; + } + } + + return ret; + } + + size_t num_constants() const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->num_constants(); + } + + // retrieve the constant name of constants_info_[idx] + const char* constant_name(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_name(static_cast(idx)); + } + + // retrieve original FQN of constants_info_[idx] + const char* constant_original_fqn(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_original_fqn(static_cast(idx)); + } + + // retrieve whether constant is from folded of constants_info_[idx] + bool constant_from_folded(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_from_folded(static_cast(idx)); + } + + size_t constant_data_size(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_data_size(static_cast(idx)); + } + + // retrieve type of constants_info_[idx] + int32_t constant_type(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_type(static_cast(idx)); + } + + // retrieve dtype of constants_info_[idx] + int32_t constant_dtype(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_dtype(static_cast(idx)); + } + + void run_const_fold( + bool inactive_buffer, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor) { + AOTInductorModel* model; + ConstantState& const_folded = inactive_buffer == use_secondary_ + ? constant_folded_ + : constant_folded_secondary_; + if (!inactive_buffer) { + // We would need to acquire a unique lock if we want to run constant + // folding on the active buffer. + std::unique_lock constants_folding_lk(model_exec_mutex_); + model = get_available_model(); + try { + auto folded_const_map = model->run_const_fold(stream, proxy_executor); + update_constant_buffer( + std::move(folded_const_map), + /* use_inactive = */ false, + /* validate_full_update = */ false); + const_folded = ConstantState::FOLDED; + } catch (...) { + std::lock_guard lk(models_mutex_); + available_models_.push_back(model); + throw; + } + } else { + std::shared_lock model_lk(model_exec_mutex_); + model = get_available_model(); + + // We swap the constant mapping to the inactive buffer in the model to run + // const run. + auto constants_map = get_constants_map(/* get_inactive= */ true); + auto constants_array = get_constants_array(/* get_inactive= */ true); + + try { + model->update_constants_map( + constants_map, /* remap_constants_array= */ false); + model->update_constants_array(constants_array); + + auto folded_const_map = model->run_const_fold(stream, proxy_executor); + update_constant_buffer( + std::move(folded_const_map), + /* use_inactive = */ true, + /* validate_full_update = */ false); + + // Swap back the model's constants mapping + constants_map = get_constants_map(/* get_inactive= */ false); + constants_array = get_constants_array(/* get_inactive= */ false); + model->update_constants_map( + constants_map, /* remap_constants_array= */ false); + model->update_constants_array(constants_array); + const_folded = ConstantState::FOLDED; + } catch (...) { + std::lock_guard lk(models_mutex_); + available_models_.push_back(model); + throw; + } + } + + { + std::lock_guard lk(models_mutex_); + pending_models_.push_back(model); + } + pending_models_available_.notify_one(); + } + + bool _is_tensor_constant_type(const size_t idx) const { + auto constant_type = models_[0]->constant_type(static_cast(idx)); + // We should skip constants + return constant_type == ConstantType::TensorConstant; + } + + bool _is_buffer_type(const size_t idx) const { + auto constant_type = models_[0]->constant_type(static_cast(idx)); + // Buffer can be optionally skipped, so if it not provided by upstream + // services, it is OK to relax the check. + return constant_type == ConstantType::Buffer; + } + + bool _is_tensor_constant_or_buffer_type(const size_t idx) const { + return _is_tensor_constant_type(idx) || _is_buffer_type(idx); + } + + void assert_all_constants( + const std::unordered_map& constants_map) { + auto num_constants = models_[0]->num_constants(); + for (size_t idx = 0; idx < num_constants; idx++) { + if (models_[0]->constant_from_folded(static_cast(idx))) { + continue; + } + + auto constant_name = + std::string(models_[0]->constant_name(static_cast(idx))); + auto it = constants_map.find(constant_name); + if (it == constants_map.end()) { + if (_is_tensor_constant_or_buffer_type(idx)) { + // tracing sometimes creates tensors that are non-existent in + // original graph. We could skip those and do a direct copy. + std::cerr << "[WARNING] Found constant or module state buffer " + << constant_name + << " in model, but not provided by user!\n"; + continue; + } + throw std::runtime_error( + std::string("Cannot find constants ") + constant_name + + std::string(" in constants_map!")); + } + } + } + + // We directly take ownership from AtenTensorHandle if constants are moved. + void update_constant_buffer( + std::unordered_map&& constants_map, + bool use_inactive, + bool validate_full_update) { + if (this->num_models() == 0) { + throw std::runtime_error("No model available in container!"); + } + if (validate_full_update) { + assert_all_constants(constants_map); + } + + ConstantState& const_folded = use_inactive == use_secondary_ + ? constant_folded_ + : constant_folded_secondary_; + const_folded = ConstantState::INITIALIZED; + + auto original_constants_map = get_constants_map(!use_inactive); + auto constants_map_to_update = get_constants_map(use_inactive); + + auto num_constants = models_[0]->num_constants(); + for (size_t idx = 0; idx < num_constants; idx++) { + auto constant_name = + std::string(models_[0]->constant_name(static_cast(idx))); + auto it = constants_map.find(constant_name); + if (it == constants_map.end() && + !(use_inactive && _is_tensor_constant_type(idx))) { + continue; + } + + AtenTensorHandle tensor; + if (it == constants_map.end()) { + aoti_torch_clone( + original_constants_map->find(constant_name)->second.get(), &tensor); + } else { + tensor = it->second; + } + + constants_map_to_update->insert_or_assign( + constant_name, RAIIAtenTensorHandle(tensor)); + } + // Update the inactive constant array. + update_array_from_map( + get_constants_array(use_inactive), constants_map_to_update); + } + + // This function updates the buffer for storing constants. + // It will update the buffer, the mapping and the array mapping. + void update_constant_buffer( + const std::unordered_map& constants_map, + bool use_inactive, + bool validate_full_update, + bool user_managed = false) { + if (this->num_models() == 0) { + throw std::runtime_error("No model available in container!"); + } + if (validate_full_update) { + assert_all_constants(constants_map); + } + + ConstantState& const_folded = use_inactive == use_secondary_ + ? constant_folded_ + : constant_folded_secondary_; + const_folded = ConstantState::INITIALIZED; + + auto original_constants_map = get_constants_map(!use_inactive); + auto constants_map_to_update = get_constants_map(use_inactive); + + auto num_constants = models_[0]->num_constants(); + for (size_t idx = 0; idx < num_constants; idx++) { + auto constant_name = + std::string(models_[0]->constant_name(static_cast(idx))); + auto it = constants_map.find(constant_name); + if (it == constants_map.end() && + !(use_inactive && _is_tensor_constant_or_buffer_type(idx))) { + continue; + } + + AtenTensorHandle tensor; + if (it == constants_map.end()) { + tensor = original_constants_map->find(constant_name)->second.get(); + } else { + tensor = it->second; + } + + if (user_managed) { + // If user managed, we pass in the pointer directly, and skip the + // copy. + constants_map_to_update->insert_or_assign( + constant_name, + MaybeOwningAtenTensorHandle(tensor, /* user_managed = */ true)); + continue; + } + + auto* constants_blob_ptr = + static_cast(get_constant_blob_ptr(use_inactive)); + + // Move the data to container handled blob. + uint8_t* internal_constants_ptr = + constants_blob_ptr + constants_internal_offset_[idx]; + void* user_constant_ptr; + int64_t constant_size; + aoti_torch_get_data_ptr(tensor, &user_constant_ptr); + aoti_torch_get_storage_size(tensor, &constant_size); +#ifdef USE_XPU + sycl::queue* queue_ptr = nullptr; + aoti_torch_get_current_sycl_queue((void**)&queue_ptr); + queue_ptr + ->memcpy(internal_constants_ptr, user_constant_ptr, constant_size) + .wait(); +#elif USE_CUDA + AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy( + internal_constants_ptr, + user_constant_ptr, + constant_size, + cudaMemcpyDefault)); +#else + memcpy(internal_constants_ptr, user_constant_ptr, constant_size); +#endif + // Generate Tensor from container handled blob. + // We extract stride and offset from provided Tensor since we do not + // guarantee that the tensor is contiguous. + AtenTensorHandle tensor_handle; + int64_t* stride; + int64_t offset; + int device_type = models_[0]->get_device_type(); + int device_idx = models_[0]->get_device_idx(); + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(tensor, &stride)); + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_storage_offset(tensor, &offset)); + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob( + internal_constants_ptr, + models_[0]->constant_ndim(idx), + models_[0]->constant_shape(idx), + stride, + offset, + models_[0]->constant_dtype(idx), + device_type, + device_idx, + &tensor_handle)); + + // Now place the tensor to constants_map. Note at this point the + // ownership of the tensor_handle will be taken over. + constants_map_to_update->insert_or_assign( + constant_name, RAIIAtenTensorHandle(tensor_handle)); + } + // Update the inactive constant array. + update_array_from_map( + get_constants_array(use_inactive), constants_map_to_update); + } + + void update_array_from_map( + const std::shared_ptr>& constants_array, + const std::shared_ptr& constants_map) { + auto num_constants = models_[0]->num_constants(); + for (size_t idx = 0; idx < num_constants; idx++) { + if (constants_map->find(models_[0]->constant_name( + static_cast(idx))) != constants_map->end()) { + constants_array->at(idx) = ConstantHandle( + constants_map + ->find(models_[0]->constant_name(static_cast(idx))) + ->second); + } + } + } + + void swap_constant_buffer() { + std::lock_guard unique_lk(model_exec_mutex_); + + auto constants_map = get_constants_map(/* get_inactive= */ true); + auto constants_array = get_constants_array(/* get_inactive= */ true); + + for (auto& model : models_) { + model->update_constants_map( + constants_map, /* remap_constants_array = */ false); + model->update_constants_array(constants_array); + } + + use_secondary_ = !use_secondary_; + } + + void free_inactive_constant_buffer() { + if (use_secondary_) { + constant_folded_ = ConstantState::NONE; + constant_blob_.reset(); + } else { + constant_folded_secondary_ = ConstantState::NONE; + constant_blob_secondary_.reset(); + } + // Free the internally held constants + int num_constants = static_cast(models_[0]->num_constants()); + std::shared_ptr to_free_map = + use_secondary_ ? constants_map_ : constants_map_secondary_; + + for (int i = 0; i < num_constants; i++) { + if (models_[0]->constant_from_folded(i)) { + auto it = to_free_map->find(models_[0]->constant_name(i)); + if (it != to_free_map->end()) { + it->second.reset(); + } + } + } + } + + size_t num_inputs() const { + return input_names_.size(); + } + + size_t num_outputs() const { + return output_names_.size(); + } + + const char* input_name(size_t idx) const { + return input_names_.at(idx).c_str(); + } + + const char* output_name(size_t idx) const { + return output_names_.at(idx).c_str(); + } + + size_t num_models() const { + return models_.size(); + } + + const char* get_in_spec() const { + return in_spec_; + } + + const char* get_out_spec() const { + return out_spec_; + } + + private: + std::vector input_names_; + std::vector output_names_; + const char* in_spec_; + const char* out_spec_; + + // Holds the blob storage for constants' at::Tensor within the container. + // This blob of memory will be managed by the container. + RAIIDataPtr constant_blob_; + RAIIDataPtr constant_blob_secondary_; + + size_t blob_size_; + std::vector constants_internal_offset_; + + // Determine which constants is being used for the model. + // If true, + // constants_map_secondary/constant_blob_secondary/constants_array_secondary + // is being used. + bool use_secondary_{false}; + + // Determine whether we have ran constant folding + ConstantState constant_folded_{ConstantState::NONE}; + ConstantState constant_folded_secondary_{ConstantState::NONE}; + + // Holds the mapping of constants to at::Tensor. + // The underlying data of at::Tensor is in either constant_blob_ (for CUDA). + // or _binary_constants_bin_start (for CPU). + std::shared_ptr constants_map_; + std::shared_ptr constants_map_secondary_; + + // Holds the indexed array of constant for faster lookup during runtime. + std::shared_ptr> constants_array_; + std::shared_ptr> constants_array_secondary_; + + // Holds all the AOTInductorModel instances owned by this container. + std::vector> models_; + + // Holds the AOTInductorModel instances available for inference. + std::vector available_models_; + + // Holds the AOTInductorModel instances that have started running + // inference and can be placed onto available_models_ upon their + // completion. + std::deque pending_models_; + + // Protects available_models_ and pending_models_. + std::mutex models_mutex_; + + // Notified whenever a model is placed onto pending_models_. + std::condition_variable pending_models_available_; + + AOTInductorModel* get_available_model() { + std::unique_lock lk(models_mutex_); + if (available_models_.empty()) { + reclaim_finished_models(lk); + } + auto* result = available_models_.back(); + available_models_.pop_back(); + return result; + } + + // This mutex is used to protect execution of model. + // We acquire the mutex in shared mode if we allow concurrent execution. + // We acquire the mutex in unique mode when we want exclusive access of the + // model. One such case is when we want to do a weight swapping. We want to + // make sure no one is executing the model. + std::shared_mutex model_exec_mutex_; + + RAIIDataPtr allocate_constant_blob() { +#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS) + return RAII_gpuMalloc(blob_size_); +#else + return RAII_cpuMalloc(blob_size_); +#endif // USE_CUDA + } + + void* get_constant_blob_ptr(bool get_inactive) { + if ((get_inactive && use_secondary_) || + (!get_inactive && !use_secondary_)) { + if (!constant_blob_) { + constant_blob_ = allocate_constant_blob(); + } + return constant_blob_.get(); + } else { + if (!constant_blob_secondary_) { + constant_blob_secondary_ = allocate_constant_blob(); + } + return constant_blob_secondary_.get(); + } + } + + std::shared_ptr get_constants_map(bool get_inactive) { + if ((get_inactive && use_secondary_) || + (!get_inactive && !use_secondary_)) { + return constants_map_; + } else { + if (!constants_map_secondary_) { + constants_map_secondary_ = std::make_shared(); + } + return constants_map_secondary_; + } + } + + std::shared_ptr> get_constants_array( + bool get_inactive) { + if ((get_inactive && use_secondary_) || + (!get_inactive && !use_secondary_)) { + return constants_array_; + } else { + if (!constants_array_secondary_) { + constants_array_secondary_ = + std::make_shared>( + models_[0]->num_constants()); + } + return constants_array_secondary_; + } + } + + void reclaim_finished_models(std::unique_lock& lk) { +#ifdef __aarch64__ + // push finished model instances to the end of pending_models_ + auto it = std::partition( + pending_models_.begin(), + pending_models_.end(), + [](AOTInductorModel* m) { return !m->is_finished(); }); +#else + // push finished model instances to the end of pending_models_ + auto it = std::stable_partition( + pending_models_.begin(), + pending_models_.end(), + [](AOTInductorModel* m) { return !m->is_finished(); }); +#endif + + if (it != pending_models_.end()) { + // We have finished model instances that can be pushed into + // available_models_ so that we don't have to be blocked on waiting + // the pending_models_available_ condition. + available_models_.insert( + available_models_.end(), it, pending_models_.end()); + pending_models_.erase(it, pending_models_.end()); + return; + } + + pending_models_available_.wait( + lk, [this]() { return !pending_models_.empty(); }); + // Let's make the schedule simple first. We always wait on the first + // pending_models_ to be complete. + auto* model = pending_models_.front(); + pending_models_.pop_front(); + lk.unlock(); + try { + model->wait_for_completion(); + } catch (...) { + lk.lock(); + available_models_.push_back(model); + throw; + } + lk.lock(); + available_models_.push_back(model); + } +}; + +} // namespace torch::aot_inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..18e0b80589622f4a664aa56ad59007bac8072df8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +namespace torch::aot_inductor { + +template +inline RAIIAtenTensorHandle scalar_to_tensor_handle(T value) { + throw std::runtime_error("Unsupported scalar_to_tensor_handle"); +} + +// Specialize for supported C++ primitive types +#define AOTI_RUNTIME_SCALAR_TO_TENSOR(dtype, ctype) \ + template <> \ + inline RAIIAtenTensorHandle scalar_to_tensor_handle(ctype value) { \ + AtenTensorHandle tensor_handle; \ + AOTI_TORCH_ERROR_CODE_CHECK( \ + aoti_torch_scalar_to_tensor_##dtype(value, &tensor_handle)); \ + return RAIIAtenTensorHandle(tensor_handle); \ + } + +AOTI_RUNTIME_SCALAR_TO_TENSOR(float32, float) +AOTI_RUNTIME_SCALAR_TO_TENSOR(float64, double) +AOTI_RUNTIME_SCALAR_TO_TENSOR(uint8, uint8_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(uint16, uint16_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(uint32, uint32_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(uint64, uint64_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(int8, int8_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(int16, int16_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(int32, int32_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(int64, int64_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(bool, bool) +AOTI_RUNTIME_SCALAR_TO_TENSOR(complex64, c10::complex) +AOTI_RUNTIME_SCALAR_TO_TENSOR(complex128, c10::complex) +#undef AOTI_RUNTIME_SCALAR_TO_TENSOR + +} // namespace torch::aot_inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h new file mode 100644 index 0000000000000000000000000000000000000000..9745f69ccf4f1ecd88a261df97ed09d2b068cdd5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h @@ -0,0 +1,169 @@ +// NOLINT +#pragma once +#ifdef USE_XPU +#include +#include +#include +#include +#include +#include + +#define ZE_CHECK(status) \ + { \ + if (status != ZE_RESULT_SUCCESS) { \ + std::stringstream ss; \ + ss << "L0 runtime error: " << std::hex << std::uppercase << status; \ + throw std::runtime_error(ss.str()); \ + } \ + } + +static ze_module_handle_t _createModule( + const uint8_t* binaryPtr, + size_t binarySize) { + sycl::device& syclDevice = + c10::xpu::get_raw_device(c10::xpu::current_device()); + auto& syclContext = c10::xpu::get_device_context(); + auto device = + sycl::get_native(syclDevice); + auto context = + sycl::get_native(syclContext); + + const char* buildFlags = ""; + const ze_module_format_t format = ZE_MODULE_FORMAT_IL_SPIRV; + ze_module_desc_t moduleDescription = {}; + moduleDescription.stype = ZE_STRUCTURE_TYPE_MODULE_DESC; + moduleDescription.format = format; + moduleDescription.inputSize = binarySize; + moduleDescription.pInputModule = (uint8_t*)binaryPtr; + moduleDescription.pBuildFlags = buildFlags; + ze_module_build_log_handle_t buildLog = nullptr; + ze_module_handle_t module = nullptr; + auto error_no = ZE_RESULT_SUCCESS; + error_no = + zeModuleCreate(context, device, &moduleDescription, &module, &buildLog); + + if (error_no != ZE_RESULT_SUCCESS) { + size_t szLog = 0; + ZE_CHECK(zeModuleBuildLogGetString(buildLog, &szLog, nullptr)); + char* strLog = (char*)malloc(szLog); + ZE_CHECK(zeModuleBuildLogGetString(buildLog, &szLog, strLog)); + std::cerr << "L0 build module failed. Log: " << strLog << std::endl; + free(strLog); + } + if (buildLog) { + ZE_CHECK(zeModuleBuildLogDestroy(buildLog)); + } + ZE_CHECK(error_no); + return module; +} + +static std::unique_ptr _createKernel( + ze_module_handle_t module, + const char* kernelName) { + assert(module); + assert(kernelName); + ze_kernel_handle_t kernel = nullptr; + ze_kernel_desc_t kernelDescription = {}; + kernelDescription.stype = ZE_STRUCTURE_TYPE_KERNEL_DESC; + kernelDescription.pNext = nullptr; + kernelDescription.flags = ZE_KERNEL_FLAG_FORCE_RESIDENCY; + kernelDescription.pKernelName = kernelName; + ZE_CHECK(zeKernelCreate(module, &kernelDescription, &kernel)); + + auto& syclContext = c10::xpu::get_device_context(); + auto mod = sycl::make_kernel_bundle< + sycl::backend::ext_oneapi_level_zero, + sycl::bundle_state::executable>( + {module, sycl::ext::oneapi::level_zero::ownership::transfer}, + syclContext); + auto fun = sycl::make_kernel( + {mod, kernel, sycl::ext::oneapi::level_zero::ownership::transfer}, + syclContext); + return std::make_unique(fun); +} + +// GPU Cpp Wrapper API +[[maybe_unused]] static std::unique_ptr loadKernel( + std::string filePath, + const std::string& funcName, + uint32_t sharedMemBytes, + const std::optional& binDir = std::nullopt) { + if (binDir) { + std::filesystem::path p1{*binDir}; + std::filesystem::path p2{filePath}; + filePath = (p1 / p2.filename()).string(); + } + + std::ifstream IFS(filePath.c_str(), std::ios::binary); + std::ostringstream OSS; + OSS << IFS.rdbuf(); + std::string data(OSS.str()); + + auto mod = _createModule( + reinterpret_cast(data.c_str()), data.size()); + + return _createKernel(mod, funcName.c_str()); +} + +// GPU Cpp Wrapper API +[[maybe_unused]] static std::unique_ptr loadKernel( + const void* start, + const void* end, + const std::string& funcName, + uint32_t sharedMemBytes) { + size_t size = reinterpret_cast(end) - + reinterpret_cast(start); + + auto mod = _createModule(reinterpret_cast(start), size); + + return _createKernel(mod, funcName.c_str()); +} + +// GPU Cpp Wrapper API +[[maybe_unused]] static void launchKernel( + std::unique_ptr& kernelPtr, + uint32_t gridX, + uint32_t gridY, + uint32_t gridZ, + uint32_t numWarps, + uint32_t sharedMemory, + void** params, + sycl::queue* queuePtr) { + std::string kernelName = + kernelPtr->get_info(); + // Currently threadsPerWarp is hard code to 32 from torch.compile to triton + // stack. + int threadsPerWarp = 32; + uint32_t numParams = kernelPtr->get_info(); + size_t globalRangeX = gridX * threadsPerWarp * numWarps; + size_t globalRangeY = gridY; + size_t globalRangeZ = gridZ; + size_t localRangeX = numWarps * threadsPerWarp; + size_t localRangeY = 1; + size_t localRangeZ = 1; + sycl::range<3> globalRange(globalRangeZ, globalRangeY, globalRangeX); + sycl::range<3> localRange(localRangeZ, localRangeY, localRangeX); + sycl::nd_range<3> parallelWorkSize(globalRange, localRange); + if (sharedMemory) { + // numParams from sycl info = user provided args + sharedMemroyBuffer + numParams -= 1; + } + // Submit the imported kernel. + auto cgf = [&](sycl::handler& cgh) { + for (uint32_t i = 0; i < numParams; ++i) { + cgh.set_arg(i, *(static_cast(params[i]))); + } + + if (sharedMemory > 0) { + constexpr int dimensions = 1; + using share_mem_t = sycl::local_accessor; + share_mem_t localBuffer = share_mem_t(sharedMemory, cgh); + cgh.set_arg(numParams, localBuffer); + cgh.parallel_for(parallelWorkSize, *kernelPtr); + } else { + cgh.parallel_for(parallelWorkSize, *kernelPtr); + } + }; + auto event = queuePtr->submit(cgf); +} +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/thread_local.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/thread_local.h new file mode 100644 index 0000000000000000000000000000000000000000..fd931c95626e4c3b6eff2f86034c003f1e8dc098 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/thread_local.h @@ -0,0 +1,160 @@ +#pragma once + +#include + +namespace torch::aot_inductor { + +template +struct ThreadLocalCachedOutputTensor; + +template <> +struct ThreadLocalCachedOutputTensor { + explicit ThreadLocalCachedOutputTensor(const RAIIAtenTensorHandle&) {} + void copy_data_from(const RAIIAtenTensorHandle& handle) { + throw std::runtime_error("can't happen"); + } + + AtenTensorHandle tensor() const { + throw std::runtime_error("can't happen"); + } +}; + +template <> +struct ThreadLocalCachedOutputTensor { + explicit ThreadLocalCachedOutputTensor(const AtenTensorHandle&) {} + void copy_data_from(const AtenTensorHandle& handle) { + throw std::runtime_error("can't happen"); + } + + AtenTensorHandle tensor() const { + throw std::runtime_error("can't happen"); + } +}; + +template <> +struct ThreadLocalCachedOutputTensor { + explicit ThreadLocalCachedOutputTensor(const ConstantHandle&) {} + void copy_data_from(const ConstantHandle& handle) { + throw std::runtime_error("can't happen"); + } + + AtenTensorHandle tensor() const { + throw std::runtime_error("can't happen"); + } +}; + +template +struct ThreadLocalCachedOutputTensor> { + explicit ThreadLocalCachedOutputTensor(const ArrayRefTensor& t) { + realloc(t); + } + + void copy_data_from(const ArrayRefTensor& t) { + if (t.numel() > capacity_) { + realloc(t); + } + std::copy(t.data(), t.data() + t.numel(), storage_.get()); + } + + AtenTensorHandle tensor() const { + return tensor_.get(); + } + + private: + void realloc(const ArrayRefTensor& t) { + capacity_ = t.numel(); + // NOLINTNEXTLINE(*arrays*) + storage_ = std::make_unique(t.numel()); + AtenTensorHandle handle = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob( + storage_.get(), + t.sizes().size(), + t.sizes().data(), + t.strides().data(), + 0, + aoti_torch_dtype>(), + t.device_type(), + t.device_idx(), + &handle)); + tensor_ = handle; + } + + // NOLINTNEXTLINE(*arrays*) + std::unique_ptr storage_; + int64_t capacity_ = 0; + RAIIAtenTensorHandle tensor_; +}; + +template +struct ThreadLocalCachedOutputArray; + +// Just needs to compile, doesn't need to do anything. +template <> +struct ThreadLocalCachedOutputArray { + explicit ThreadLocalCachedOutputArray(const RAIIAtenTensorHandle&) { + throw std::runtime_error("can't happen"); + } + + // Not supported yet! We would need to put contiguous() or + // expect_contiguous() into the ABI. + void copy_data_from(const RAIIAtenTensorHandle&) { + throw std::runtime_error("can't happen"); + } + + template + ArrayRefTensor arrayref_tensor() const { + throw std::runtime_error("can't happen"); + } +}; + +// Just needs to compile, doesn't need to do anything. +template <> +struct ThreadLocalCachedOutputArray { + explicit ThreadLocalCachedOutputArray(const ConstantHandle&) { + throw std::runtime_error("can't happen"); + } + + // Not supported yet! We would need to put contiguous() or + // expect_contiguous() into the ABI. + void copy_data_from(const ConstantHandle&) { + throw std::runtime_error("can't happen"); + } + + template + ArrayRefTensor arrayref_tensor() const { + throw std::runtime_error("can't happen"); + } +}; + +template +struct ThreadLocalCachedOutputArray> { + explicit ThreadLocalCachedOutputArray(const ArrayRefTensor& t) {} + + template < + typename U, + std::enable_if_t< + std::is_same_v, std::remove_const_t>, + bool> = true> + ArrayRefTensor arrayref_tensor() const { + return tensor_; + } + + void copy_data_from(const ArrayRefTensor& t) { + if (t.numel() > capacity_) { + capacity_ = t.numel(); + // NOLINTNEXTLINE(*arrays*) + storage_ = std::make_unique(capacity_); + } + std::copy(t.data(), t.data() + t.numel(), storage_.get()); + tensor_ = t; + tensor_.set_arrayref(MiniArrayRef(storage_.get(), t.numel())); + } + + private: + // NOLINTNEXTLINE(*arrays*) + std::unique_ptr storage_; + uint32_t capacity_ = 0; + ArrayRefTensor tensor_; +}; + +} // namespace torch::aot_inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/utils.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..b6c009805c71d1292b4d440bdb89be0c4e5cf195 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/utils.h @@ -0,0 +1,372 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. +#include + +#if defined(__GNUC__) || defined(__clang__) +#define AOTI_NOINLINE __attribute__((noinline)) +#elif _MSC_VER +#define AOTI_NOINLINE __declspec(noinline) +#else +#define AOTI_NOINLINE +#endif + +AOTI_NOINLINE static void throw_exception( + const char* call, + const char* file, + int64_t line) { + std::stringstream ss; + ss << call << " API call failed at " << file << ", line " << line; + throw std::runtime_error(ss.str()); +} + +#define AOTI_TORCH_ERROR_CODE_CHECK(call) \ + if ((call) != AOTI_TORCH_SUCCESS) { \ + throw_exception(#call, __FILE__, __LINE__); \ + } + +using AOTIRuntimeError = int32_t; +#define AOTI_RUNTIME_SUCCESS 0 +#define AOTI_RUNTIME_FAILURE 1 + +#define AOTI_RUNTIME_ERROR_CODE_CHECK(call) \ + if ((call) != AOTI_RUNTIME_SUCCESS) { \ + throw_exception(#call, __FILE__, __LINE__); \ + } + +namespace torch::aot_inductor { + +using DeleterFnPtr = void (*)(void*); + +inline void noop_deleter(void*) {} + +inline void delete_tensor_object(void* ptr) { + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_delete_tensor_object(reinterpret_cast(ptr))); +} + +// RAIIAtenTensorHandle steals the tensor objects created by the libtorch C ABI +class RAIIAtenTensorHandle { + public: + RAIIAtenTensorHandle() : handle_(nullptr, noop_deleter) {} + RAIIAtenTensorHandle(const RAIIAtenTensorHandle& other) = delete; + RAIIAtenTensorHandle& operator=(const RAIIAtenTensorHandle& other) = delete; + + // Steal the ownership from another RAIIAtenTensorHandle using std::move + RAIIAtenTensorHandle(RAIIAtenTensorHandle&& other) = default; + RAIIAtenTensorHandle& operator=(RAIIAtenTensorHandle&& other) = default; + + // Steal the ownership from raw AtenTensorHandle + RAIIAtenTensorHandle(AtenTensorHandle handle) + : handle_(handle, delete_tensor_object) {} + + ~RAIIAtenTensorHandle() { + handle_.reset(); + } + + // Return a raw AtenTensorHandle to be used by aoti_torch functions + // Note: this function does NOT transfer the ownership of the handle + operator AtenTensorHandle() const { + return handle_.get(); + } + + AtenTensorHandle release() { + return handle_.release(); + } + + AtenTensorHandle get() const { + return handle_.get(); + } + + void reset() { + handle_.reset(); + } + + int64_t size(int64_t d) { + int64_t size = 0; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(handle_.get(), d, &size)); + return size; + } + + int64_t stride(int64_t d) { + int64_t stride = 0; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_stride(handle_.get(), d, &stride)); + return stride; + } + + int64_t storage_offset() { + int64_t storage_offset = 0; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_storage_offset(handle_.get(), &storage_offset)); + return storage_offset; + } + + void* data_ptr() const { + void* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_data_ptr(handle_.get(), &result)); + return result; + } + + int64_t* sizes() const { + int64_t* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle_.get(), &result)); + return result; + } + + int64_t* strides() const { + int64_t* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle_.get(), &result)); + return result; + } + + private: + std::unique_ptr handle_; +}; + +class MaybeOwningAtenTensorHandle { + public: + MaybeOwningAtenTensorHandle() : handle_(nullptr), raii_handle_() {} + // We skip copy constructor as MaybeOwningAtenTensorHandle might be RAII which + // makes it undefined. + MaybeOwningAtenTensorHandle(const MaybeOwningAtenTensorHandle& other) = + delete; + MaybeOwningAtenTensorHandle& operator=( + const MaybeOwningAtenTensorHandle& other) = delete; + + // Move constructor and move assignment operator + MaybeOwningAtenTensorHandle(MaybeOwningAtenTensorHandle&& other) = default; + MaybeOwningAtenTensorHandle& operator=(MaybeOwningAtenTensorHandle&& other) = + default; + + // Steal the ownership from another RAIIAtenTensorHandle using std::move + MaybeOwningAtenTensorHandle(RAIIAtenTensorHandle&& other) + : raii_handle_(std::move(other)) { + handle_ = raii_handle_.get(); + } + MaybeOwningAtenTensorHandle& operator=(RAIIAtenTensorHandle&& other) { + raii_handle_ = std::move(other); + handle_ = raii_handle_.get(); + return *this; + } + + // By default, steal the ownership from raw AtenTensorHandle + MaybeOwningAtenTensorHandle(AtenTensorHandle handle) : raii_handle_(handle) { + handle_ = raii_handle_.get(); + } + + // If user_managed is true, we do not steal the ownership. + MaybeOwningAtenTensorHandle(AtenTensorHandle handle, bool user_managed) { + if (user_managed) { + aoti_torch_new_tensor_handle(handle, &handle_); + } else { + raii_handle_ = RAIIAtenTensorHandle(handle); + handle_ = raii_handle_.get(); + } + } + + ~MaybeOwningAtenTensorHandle() { + // This is no-op if we don't hold raii_handle with the + // MaybeOwningAtenTensorHandle. + raii_handle_.reset(); + } + + // Return a raw AtenTensorHandle to be used by aoti_torch functions + // Note: this function does NOT transfer the ownership of the handle + operator AtenTensorHandle() const { + return handle_; + } + + AtenTensorHandle release() { + if (raii_handle_) { + return raii_handle_.release(); + } else { + AtenTensorHandle handle = handle_; + handle_ = nullptr; + return handle; + } + } + + AtenTensorHandle get() const { + return handle_; + } + + void reset() { + handle_ = nullptr; + raii_handle_.reset(); + } + + int64_t size(int64_t d) { + int64_t size = 0; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(handle_, d, &size)); + return size; + } + + int64_t stride(int64_t d) { + int64_t stride = 0; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_stride(handle_, d, &stride)); + return stride; + } + + int64_t storage_offset() { + int64_t storage_offset = 0; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_storage_offset(handle_, &storage_offset)); + return storage_offset; + } + + void* data_ptr() const { + void* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle_, &result)); + return result; + } + + int64_t* sizes() const { + int64_t* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle_, &result)); + return result; + } + + int64_t* strides() const { + int64_t* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle_, &result)); + return result; + } + + private: + // handle_ is the underlying AtenTensorHandle of raii_handle_ if raii_handle_ + // exists. Otherwise it would just be the AtenTensorHandle passed in by users. + AtenTensorHandle handle_; + RAIIAtenTensorHandle raii_handle_; +}; + +// Steal the ownership from raw AtenTensorHandle to RAIIAtenTensorHandle +inline std::vector steal_from_raw_handles_to_raii_handles( + AtenTensorHandle* handles, + size_t size) { + std::vector result; + result.reserve(size); + for (size_t i = 0; i < size; i++) { + result.emplace_back(handles[i]); + handles[i] = nullptr; + } + return result; +} + +inline AtenTensorHandle reinterpret_tensor_wrapper( + AtenTensorHandle self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset) { + AtenTensorHandle result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__reinterpret_tensor( + self, ndim, sizes_ptr, strides_ptr, storage_offset, &result)); + return result; +} + +inline void* get_data_ptr_wrapper(AtenTensorHandle tensor) { + void* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(tensor, &result)); + return result; +} + +inline AtenTensorHandle unwrap_raii_handle_if_needed( + const RAIIAtenTensorHandle& handle) { + return handle.get(); +} + +inline RAIIAtenTensorHandle wrap_with_raii_handle_if_needed( + AtenTensorHandle handle) { + return RAIIAtenTensorHandle(handle); +} + +class ConstantHandle { + public: + ConstantHandle() = default; + + explicit ConstantHandle(AtenTensorHandle handle) : handle_(handle) { + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle_, &data_)); + } + + operator AtenTensorHandle() const { + return handle_; + } + + AtenTensorHandle tensor() const { + return handle_; + } + + AtenTensorHandle get() const { + return handle_; + } + + void* data_ptr() const { + return data_; + } + + int64_t* sizes() const { + int64_t* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle_, &result)); + return result; + } + + int64_t* strides() const { + int64_t* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle_, &result)); + return result; + } + + private: + AtenTensorHandle handle_{}; + void* data_ = nullptr; +}; + +inline void* get_data_ptr_wrapper(const ConstantHandle& constant) { + return constant.data_ptr(); +} + +inline const ConstantHandle& unwrap_raii_handle_if_needed( + const ConstantHandle& handle) { + return handle; +} + +// Shouldn't be called. +inline AtenTensorHandle wrap_with_raii_handle_if_needed( + const ConstantHandle& handle) = delete; + +// DANGEROUS. Do not call unless you explicitly intend to get a reference to a +// temporary value, which will expire at the end of the current expression. +// This should only be called in cases where the C-shim API expects an optional +// input argument (passed by pointer), and a temporary needs to be passed to it. +template +T& temporary_reference(T&& t) { + return t; +} + +#define CACHE_TORCH_DTYPE(typename) \ + static auto cached_torch_dtype_##typename = aoti_torch_dtype_##typename() + +#define CACHE_TORCH_DEVICE(device) \ + static auto cached_torch_device_type_##device = \ + aoti_torch_device_type_##device() + +#define CACHE_TORCH_LAYOUT(layout) \ + static auto cached_torch_layout_##layout = aoti_torch_layout_##layout() + +#define CACHE_TORCH_MEMORY_FORMAT(format) \ + static auto cached_torch_memory_format_##format = \ + aoti_torch_memory_format_##format() + +} // namespace torch::aot_inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/utils_cuda.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/utils_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..a4f1706ec7fb6a7d0addf3725004c203a7c4a556 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/utils_cuda.h @@ -0,0 +1,63 @@ +#pragma once + +#ifdef USE_CUDA +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. +#include + +#include +#include +#ifndef USE_ROCM +#include +#include +#include +#endif + +namespace torch::aot_inductor { + +inline void delete_cuda_guard(void* ptr) { + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_delete_cuda_guard(reinterpret_cast(ptr))); +} + +inline void delete_cuda_stream_guard(void* ptr) { + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_cuda_stream_guard( + reinterpret_cast(ptr))); +} + +class AOTICudaGuard { + public: + AOTICudaGuard(int32_t device_index) : guard_(nullptr, delete_cuda_guard) { + CUDAGuardHandle ptr = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_create_cuda_guard(device_index, &ptr)); + guard_.reset(ptr); + } + + void set_index(int32_t device_index) { + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_cuda_guard_set_index(guard_.get(), device_index)); + } + + private: + std::unique_ptr guard_; +}; + +class AOTICudaStreamGuard { + public: + AOTICudaStreamGuard(cudaStream_t stream, int32_t device_index) + : guard_(nullptr, delete_cuda_stream_guard) { + CUDAStreamGuardHandle ptr = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_create_cuda_stream_guard(stream, device_index, &ptr)); + guard_.reset(ptr); + } + + private: + std::unique_ptr guard_; +}; + +} // namespace torch::aot_inductor +#endif // USE_CUDA diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/utils_xpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/utils_xpu.h new file mode 100644 index 0000000000000000000000000000000000000000..30f7e717b715c05fafd30268f34b5d9aa3047343 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_runtime/utils_xpu.h @@ -0,0 +1,56 @@ +#pragma once + +#ifdef USE_XPU +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. +#include +#include + +namespace torch::aot_inductor { + +inline void delete_xpu_guard(void* ptr) { + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_delete_xpu_guard(reinterpret_cast(ptr))); +} + +inline void delete_xpu_stream_guard(void* ptr) { + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_xpu_stream_guard( + reinterpret_cast(ptr))); +} + +class AOTIXpuGuard { + public: + AOTIXpuGuard(int32_t device_index) : guard_(nullptr, delete_xpu_guard) { + XPUGuardHandle ptr = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_create_xpu_guard(device_index, &ptr)); + guard_.reset(ptr); + } + + void set_index(int32_t device_index) { + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_xpu_guard_set_index(guard_.get(), device_index)); + } + + private: + std::unique_ptr guard_; +}; + +class AOTIXpuStreamGuard { + public: + AOTIXpuStreamGuard(void* stream, int32_t device_index) + : guard_(nullptr, delete_xpu_stream_guard) { + XPUStreamGuardHandle ptr = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_create_xpu_stream_guard(stream, device_index, &ptr)); + guard_.reset(ptr); + } + + private: + std::unique_ptr guard_; +}; + +} // namespace torch::aot_inductor +#endif // USE_XPU diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/c/shim.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/c/shim.h new file mode 100644 index 0000000000000000000000000000000000000000..6a23c9d465c7f087d3d63a0335d369a080ac052a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/c/shim.h @@ -0,0 +1,804 @@ +#ifndef AOTI_TORCH_SHIM +#define AOTI_TORCH_SHIM + +#include +#include + +// This header defines a stable C API for certain ATen functionality in +// libtorch. The AOTInductor compiled model.so will only refer to this header +// instead of other headers from aten/c10, which means it will NOT be able to +// directly use any data structures or call functions from libtorch. +// +// What problems are we trying to solve here? Direct use of aten/c10 APIs +// means use of C++ APIs on a library that doesn't have any ABI compatibility +// guarantees. However, we want model.so to remain usable across updates +// to the PyTorch C++ libraries, which requires a stable ABI. By introducing +// a C shim layer, we can minimize the surface that will cause breakage. The +// corresponding software stack can be illustrated as follows: +// +// |--------------------------------| +// | inference service code | +// |--------------------------------| +// | model.so | +// |--------------|-----------------| +// | | +// | libtorch.so | +// |--------------------------------| +// +// The general guidelines for the C API: +// +// - No exceptions, return an explicit error code to be checked at call site +// - Only pointers (AtenTensorHandle counts), integers and floats in headers +// +// If you want to make changes to this header, you MUST MAINTAIN ABI +// compatibility. Typically, this means you will have to add a _v2 version +// of a function that you, e.g., want to add a new function parameter to, and +// maintain the old and new versions of the APIs until all old model.so +// go out of use. + +#ifdef __GNUC__ +#define AOTI_TORCH_EXPORT __attribute__((__visibility__("default"))) +#else // !__GNUC__ +#ifdef _WIN32 +// PyTorch2 doesn't currently work on Windows. Exporting these APIs can lead +// to symbol clashes at link time if libtorch is included in a DLL and binary +// that depends on the DLL. As a short term fix, we don't export the symbols. +// In the long term, this will need to be addressed when Windows is supported. +#ifdef OVRSOURCE +// Do not export AOTI on Windows for internal builds +#define AOTI_TORCH_EXPORT +#else /* OVRSOURCE */ +#ifdef EXPORT_AOTI_FUNCTIONS +#define AOTI_TORCH_EXPORT __declspec(dllexport) +#else +#define AOTI_TORCH_EXPORT __declspec(dllimport) +#endif +#endif /* OVRSOURCE */ +#else // !_WIN32 +#define AOTI_TORCH_EXPORT +#endif // _WIN32 +#endif // __GNUC__ + +// The following files are implemented in a header-only way and are guarded by +// test/cpp/aoti_abi_check +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// AtenTensorHandle represents an abstract notion of Tensor that can be passed +// between model.so and libtorch.so. The contents of the structure itself +// are private; model.so is not allowed to access any fields directly, it must +// go through functions defined in this ABI. Under the hood, this is +// represented as at::Tensor*, but we reserve the right to change this (and in +// fact, we probably should change it to at::TensorImpl* at least). +// +// An AtenTensorHandle can be owning (please check the API reference for exact +// ownership/borrow semantics). If you have an owning AtenTensorHandle +// in model.so, you are obligated to aoti_torch_delete_tensor_object when you +// are done. You can use the helper C++ class RAIIAtenTensorHandle +// (see aot_runtime/model.h) to ensure the deallocator is called in RAII style +// (note that RAIIAtenTensorHandle is private to model.so, and never crosses +// the ABI boundary.) +struct AtenTensorOpaque; +using AtenTensorHandle = AtenTensorOpaque*; + +struct AtenGeneratorOpaque; +using AtenGeneratorHandle = AtenGeneratorOpaque*; + +struct AOTIProxyExecutorOpaque; +using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque*; + +using AOTITorchError = int32_t; +#define AOTI_TORCH_SUCCESS 0 +#define AOTI_TORCH_FAILURE 1 + +// Getter functions for retrieving various constants from the runtime, that +// can subsequently be passed to other aoti_* functions. By hiding these +// behind functions, the precise value of device/dtype is NOT part of the +// ABI contract. (In practice, aten/c10 is pretty good about not renumbering +// these, so we probably could later switch to having these in the ABI, if +// desired for perf reasons.) +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu(); +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda(); +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_meta(); +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_xpu(); +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_mps(); +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_privateuse1(); + +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fn(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2fnuz(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fnuz(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bfloat16(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float16(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float32(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float64(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint8(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint16(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint32(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint64(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int8(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int16(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int32(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int64(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bool(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex32(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex64(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex128(); +AOTI_TORCH_EXPORT size_t aoti_torch_dtype_element_size(int32_t dtype); + +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_strided(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_coo(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_csr(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_csc(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_bsr(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_bsc(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout__mkldnn(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_jagged(); + +AOTI_TORCH_EXPORT int32_t aoti_torch_memory_format_contiguous_format(); +AOTI_TORCH_EXPORT int32_t aoti_torch_memory_format_channels_last(); +AOTI_TORCH_EXPORT int32_t aoti_torch_memory_format_channels_last_3d(); +AOTI_TORCH_EXPORT int32_t aoti_torch_memory_format_preserve_format(); + +// Get TORCH_ABI_VERSION of the built libtorch.so +AOTI_TORCH_EXPORT uint64_t aoti_torch_abi_version(); + +// Functions for converting a single-element tensor to a scalar value +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_float16(AtenTensorHandle tensor, c10::Half* ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_float32(AtenTensorHandle tensor, float* ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_float64(AtenTensorHandle tensor, double* ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_uint8(AtenTensorHandle tensor, uint8_t* ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_uint16(AtenTensorHandle tensor, uint16_t* ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_uint32(AtenTensorHandle tensor, uint32_t* ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_uint64(AtenTensorHandle tensor, uint64_t* ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_int8(AtenTensorHandle tensor, int8_t* ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_int16(AtenTensorHandle tensor, int16_t* ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_int32(AtenTensorHandle tensor, int32_t* ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_int64(AtenTensorHandle tensor, int64_t* ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_bool(AtenTensorHandle tensor, bool* ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_bfloat16(AtenTensorHandle tensor, c10::BFloat16* ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_complex64( + AtenTensorHandle tensor, + c10::complex* ret_value); + +// Functions for wrapping a scalar value to a single-element tensor +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float32( + float value, + AtenTensorHandle* ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float64( + double value, + AtenTensorHandle* ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint8( + uint8_t value, + AtenTensorHandle* ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint16( + uint16_t value, + AtenTensorHandle* ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint32( + uint32_t value, + AtenTensorHandle* ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint64( + uint64_t value, + AtenTensorHandle* ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int8( + int8_t value, + AtenTensorHandle* ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int16( + int16_t value, + AtenTensorHandle* ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int32( + int32_t value, + AtenTensorHandle* ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int64( + int64_t value, + AtenTensorHandle* ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_scalar_to_tensor_bool(bool value, AtenTensorHandle* ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_complex64( + c10::complex value, + AtenTensorHandle* ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_complex128( + c10::complex value, + AtenTensorHandle* ret_new_tensor); + +AOTI_TORCH_EXPORT bool aoti_torch_grad_mode_is_enabled(); +AOTI_TORCH_EXPORT void aoti_torch_grad_mode_set_enabled(bool enabled); + +// Free the tensor object +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_delete_tensor_object(AtenTensorHandle tensor); + +// Get a pointer to the underlying storage data +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_data_ptr( + AtenTensorHandle tensor, + void** ret_data_ptr // returns borrowed reference +); + +// Get the nbytes of the underlying storage +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_storage_size(AtenTensorHandle tensor, int64_t* ret_size); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_dim(AtenTensorHandle tensor, int64_t* ret_dim); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_numel(AtenTensorHandle tensor, int64_t* ret_numel); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_storage_numel(AtenTensorHandle tensor, int64_t* ret_numel); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_sizes( + AtenTensorHandle tensor, + int64_t** ret_sizes // returns borrowed reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_size(AtenTensorHandle tensor, int64_t d, int64_t* ret_size); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_strides( + AtenTensorHandle tensor, + int64_t** ret_strides // returns borrowed reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_stride(AtenTensorHandle tensor, int64_t d, int64_t* ret_stride); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_dtype(AtenTensorHandle tensor, int32_t* ret_dtype); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_device_type(AtenTensorHandle tensor, int32_t* ret_device_type); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_device_index(AtenTensorHandle tensor, int32_t* ret_device_index); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset( + AtenTensorHandle tensor, + int64_t* ret_storage_offset); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_is_contiguous(AtenTensorHandle tensor, bool* ret_is_contiguous); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_new_tensor_handle( + AtenTensorHandle orig_handle, + AtenTensorHandle* new_handle); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__alloc_from_pool( + AtenTensorHandle self, + int64_t offset_bytes, + int32_t dtype, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + AtenTensorHandle* ret_new_tensor); + +// This function will create a new tensor object and its pointer is returned +// through *out. The caller is responsible for wrapping the tensor pointer +// with RAIIAtenTensorHandle which will call aoti_torch_delete_tensor_object +// when going out of scope. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__reinterpret_tensor( + AtenTensorHandle self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + AtenTensorHandle* ret_new_tensor // returns new reference +); + +// This function will create a new tensor object and its pointer is returned +// through *out. The caller is responsible for wrapping the tensor pointer +// with RAIIAtenTensorHandle which will call aoti_torch_delete_tensor_object +// when going out of scope. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AtenTensorHandle* ret_new_tensor // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_as_strided( + AtenTensorHandle self, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + AtenTensorHandle* ret); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AtenTensorHandle* ret // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AtenTensorHandle* ret, // returns new reference + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag( + AtenTensorHandle weight, + AtenTensorHandle indices, + AtenTensorHandle offsets, + int32_t scale_grad_by_freq, + int32_t mode, + int32_t sparse, + AtenTensorHandle per_sample_weights, // optional argument + int32_t include_last_offset, + int32_t padding_idx, + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__fft_c2c( + AtenTensorHandle self, + const int64_t* dim_ptr, + int64_t dim_size, + int64_t normalization, + int32_t forward, + AtenTensorHandle* ret // returns new reference +); + +// This version is deprecated. We will remove it later +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + double dropout_p, + bool is_causal, + bool return_debug_mask, + double scale, + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3, // returns new reference + int64_t* ret4, + int64_t* ret5, + AtenTensorHandle* ret6, // returns new reference + AtenTensorHandle* ret7, // returns new reference + AtenTensorHandle* ret8 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch__scaled_dot_product_flash_attention_v2( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + double dropout_p, + int is_causal, + int return_debug_mask, + double* scale, // optional argument + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3, // returns new reference + int64_t* ret4, + int64_t* ret5, + AtenTensorHandle* ret6, // returns new reference + AtenTensorHandle* ret7, // returns new reference + AtenTensorHandle* ret8 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch__scaled_dot_product_efficient_attention( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + AtenTensorHandle attn_bias, // optional argument + int compute_log_sumexp, + double dropout_p, + int is_causal, + double* scale, // optional argument + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm( + AtenTensorHandle self, + AtenTensorHandle mat2, + AtenTensorHandle bias, + int32_t* out_dtype, + AtenTensorHandle scale_a, + AtenTensorHandle scale_b, + AtenTensorHandle scale_result, + int8_t use_fast_accum, + AtenTensorHandle* ret0, + AtenTensorHandle* ret1); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm_v2( + AtenTensorHandle self, + AtenTensorHandle mat2, + AtenTensorHandle scale_a, + AtenTensorHandle scale_b, + AtenTensorHandle bias, + AtenTensorHandle scale_result, + int32_t* out_dtype, + int8_t use_fast_accum, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution( + AtenTensorHandle input, + AtenTensorHandle weight, + AtenTensorHandle bias, // optional argument + const int64_t* stride_ptr, + int64_t stride_size, + const int64_t* padding_ptr, + int64_t padding_size, + const int64_t* dilation_ptr, + int64_t dilation_size, + int transposed, + const int64_t* output_padding_ptr, + int64_t output_padding_size, + int64_t groups, + AtenTensorHandle* ret // returns new reference +); + +// This function will create a new uninitialized tensor object +// and its pointer is returned through *ret. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret); + +// WARNING: This will be deprecated. Use aoti_torch_copy_ instead. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_tensor_copy_(AtenTensorHandle src, AtenTensorHandle dst); + +// Make the tensor referred to by dst an alias for the tensor referred +// to by src. The two tensors must still be deleted with +// aoti_torch_delete_tensor separately (or not) as before the call. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_assign_tensors(AtenTensorHandle src, AtenTensorHandle dst); + +// Make a shallow copy of the tensor referred to by src and assign +// it to the handle in the ret_dst. This is similar to the above +// aoti_torch_assign_tensors function, but creates and sets the +// ret_dst from within. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_assign_tensors_out(AtenTensorHandle src, AtenTensorHandle* ret_dst); + +// This function will create a new tensor object and its pointer is returned +// through *ret. The caller is responsible for wrapping the tensor pointer +// with RAIIAtenTensorHandle which will call aoti_torch_delete_tensor_object +// when going out of scope. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_clone(AtenTensorHandle self, AtenTensorHandle* ret); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_clone_preserve_strides(AtenTensorHandle self, AtenTensorHandle* ret); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_addmm_out( + AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat1, + AtenTensorHandle mat2, + float beta, + float alpha); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_bmm_out( + AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat2); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_copy_( + AtenTensorHandle self, + AtenTensorHandle src, + int32_t non_blocking); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mm_out( + AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat2); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__mm_plus_mm_out( + AtenTensorHandle out, + AtenTensorHandle a, + AtenTensorHandle b, + AtenTensorHandle c, + AtenTensorHandle d); + +// This will soon be deprecated after ao_quantization is complete. +// Please refrain from using this or increasing callsites. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu_wrapped_fbgemm_pack_gemm_matrix_fp16( + AtenTensorHandle weight, + AtenTensorHandle* out); + +// This will soon be deprecated after ao_quantization is complete. +// Please refrain from using this or increasing callsites. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__wrapped_linear_prepack( + AtenTensorHandle weight, + AtenTensorHandle weight_scale, + AtenTensorHandle weight_zero_point, + AtenTensorHandle bias, + AtenTensorHandle* out); + +// This will soon be deprecated after ao_quantization is complete. +// Please refrain from using this or increasing callsites. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight( + AtenTensorHandle input, + AtenTensorHandle weight, + AtenTensorHandle bias, + int64_t out_channel, + AtenTensorHandle* out); + +// This will soon be deprecated after ao_quantization is complete. +// Please refrain from using this or increasing callsites. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu__wrapped_quantized_linear_prepacked( + AtenTensorHandle input, + AtenTensorHandle input_scale, + AtenTensorHandle input_zero_point, + AtenTensorHandle weight, + AtenTensorHandle out_scale, + AtenTensorHandle out_zeropoint, + int64_t out_channel, + AtenTensorHandle* out); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor( + AtenTensorHandle repeats, + int64_t* output_size, + AtenTensorHandle* out); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_check_inf_and_nan(const char* tensor_name, AtenTensorHandle tensor); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_out( + AtenTensorHandle out, + AtenTensorHandle self, + int64_t dim, + AtenTensorHandle index, + AtenTensorHandle src); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_reduce_out( + AtenTensorHandle out, + AtenTensorHandle self, + int64_t dim, + AtenTensorHandle index, + AtenTensorHandle src, + const char* reduce, + int32_t include_self); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_index_put_out( + AtenTensorHandle out, + AtenTensorHandle self, + const AtenTensorHandle* indices, + const uint32_t num_indices, + const AtenTensorHandle values, + bool accumulate); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_as_real( + AtenTensorHandle self, + AtenTensorHandle* ret // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype( + AtenTensorHandle self, + int32_t dtype, + AtenTensorHandle* ret // returns new reference +); + +AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( + AtenTensorHandle self, + const char* msg); + +// When AOTI debug printer option is enabled, this function will be invoked to +// torch pickle save the intermediate tensor for debugging purpose. +AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle( + AtenTensorHandle self, + const char* tensor_name, + const char* launch_prefix, + const char* kernel_name); + +// helpers for converting between StableIValue and actual IValues +using StableIValue = uint64_t; + +class TorchLibraryOpaque; +using TorchLibraryHandle = TorchLibraryOpaque*; + +// stable corollary to torch::Library constructor with Kind::IMPL +// will create a new torch::Library object on the heap +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_impl( + const char* ns, + const char* k, + const char* file, + uint32_t line, + TorchLibraryHandle* ret_new_torch_lib); + +// stable corollary to torch::Library constructor with Kind::DEF +// will create a new torch::Library object on the heap +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_def( + const char* ns, + const char* file, + uint32_t line, + TorchLibraryHandle* ret_new_torch_lib); + +// stable corollary to torch::Library constructor with Kind::FRAGMENT +// will create a new torch::Library object on the heap +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_fragment( + const char* ns, + const char* file, + uint32_t line, + TorchLibraryHandle* ret_new_torch_lib); + +// stable corollary to torch::Library method m.impl(), should be +// called from StableLibrary +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl( + TorchLibraryHandle self, + const char* name, + void (*fn)(StableIValue*, uint64_t, uint64_t)); + +// stable corollary to torch::Library method m.def(), should be +// called from StableLibrary +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_library_def(TorchLibraryHandle self, const char* schema); + +// the above stable constructors for torch::Library add Library objects +// to the heap. if you are calling those functions directly, please use +// this function to free the Library's memory. The more user friendly +// alternative is to use StableLibrary, which will free its handle upon +// destruction +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_delete_library_object(TorchLibraryHandle tlh); + +// calls the op overload defined by a given opName, overloadName, and a +// stack of StableIValues. This call will populate any return values of the +// op into the stack in their StableIValue form, with ret0 at index 0, ret1 +// at index 1, and so on. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_call_dispatcher( + const char* opName, + const char* overloadName, + StableIValue* stack); + +#ifdef USE_CUDA + +struct CUDAGuardOpaque; +using CUDAGuardHandle = CUDAGuardOpaque*; + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_cuda_guard( + int32_t device_index, + CUDAGuardHandle* ret_guard // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_delete_cuda_guard(CUDAGuardHandle guard); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cuda_guard_set_index(CUDAGuardHandle guard, int32_t device_index); + +struct CUDAStreamGuardOpaque; +using CUDAStreamGuardHandle = CUDAStreamGuardOpaque*; + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_cuda_stream_guard( + void* stream, + int32_t device_index, + CUDAStreamGuardHandle* ret_guard // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream); + +#endif // USE_CUDA + +// See `ProxyExecutor Design Note` in ir.py for more details +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function( + AOTIProxyExecutorHandle proxy_executor, + int extern_node_index, + int num_ints, + int64_t* flatten_int_args, + int num_tensors, + AtenTensorHandle* flatten_tensor_args); + +AOTI_TORCH_EXPORT void aoti_torch_check( + bool cond, + const char* func, + const char* file, + uint32_t line, + const char* msg); + +#ifdef STRIP_ERROR_MESSAGES +#define AOTI_TORCH_CHECK(cond, ...) \ + if (!(cond)) { \ + aoti_torch_check( \ + false, \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \ + } +#else +#define AOTI_TORCH_CHECK(cond, ...) \ + if (!(cond)) { \ + aoti_torch_check( \ + false, \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \ + } +#endif + +AOTI_TORCH_EXPORT void aoti_torch_warn( + const char* func, + const char* file, + uint32_t line, + const char* msg); + +#ifdef DISABLE_WARN +#define AOTI_TORCH_WARN(...) ((void)0); +#else +#define AOTI_TORCH_WARN(...) \ + aoti_torch_warn( \ + __func__, __FILE__, static_cast(__LINE__), #__VA_ARGS__); +#endif + +#ifdef __cplusplus +} // extern "C" + +template +int32_t aoti_torch_dtype() = delete; + +#define DEFINE_DTYPE_SPECIALIZATION(ctype, typename) \ + template <> \ + inline int32_t aoti_torch_dtype() { \ + return aoti_torch_dtype_##typename(); \ + } + +DEFINE_DTYPE_SPECIALIZATION(c10::BFloat16, bfloat16) +DEFINE_DTYPE_SPECIALIZATION(c10::Half, float16) +DEFINE_DTYPE_SPECIALIZATION(c10::complex, complex64) +DEFINE_DTYPE_SPECIALIZATION(float, float32) +DEFINE_DTYPE_SPECIALIZATION(double, float64) +DEFINE_DTYPE_SPECIALIZATION(uint8_t, uint8) +DEFINE_DTYPE_SPECIALIZATION(int8_t, int8) +DEFINE_DTYPE_SPECIALIZATION(int16_t, int16) +DEFINE_DTYPE_SPECIALIZATION(int32_t, int32) +DEFINE_DTYPE_SPECIALIZATION(int64_t, int64) +DEFINE_DTYPE_SPECIALIZATION(bool, bool) + +#endif +#endif // AOTI_TORCH_SHIM diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/c/shim_cpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/c/shim_cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..c7b713bf7f877e4ead616a823a152d51654a873f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/c/shim_cpu.h @@ -0,0 +1,251 @@ +#ifndef AOTI_TORCH_SHIM_CPU +#define AOTI_TORCH_SHIM_CPU + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#if AT_MKLDNN_ENABLED() + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu_mkldnn__convolution_pointwise_binary( + AtenTensorHandle X, + AtenTensorHandle other, + AtenTensorHandle W, + AtenTensorHandle* B, + const int64_t* padding, + int64_t padding_len_, + const int64_t* stride, + int64_t stride_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + const char* binary_attr, + double* alpha, + const char** unary_attr, + const double** unary_scalars, + int64_t unary_scalars_len_, + const char** unary_algorithm, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu_mkldnn__convolution_pointwise_binary_( + AtenTensorHandle other, + AtenTensorHandle X, + AtenTensorHandle W, + AtenTensorHandle* B, + const int64_t* padding, + int64_t padding_len_, + const int64_t* stride, + int64_t stride_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + const char* binary_attr, + double* alpha, + const char** unary_attr, + const double** unary_scalars, + int64_t unary_scalars_len_, + const char** unary_algorithm, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mkldnn__convolution_pointwise( + AtenTensorHandle X, + AtenTensorHandle W, + AtenTensorHandle* B, + const int64_t* padding, + int64_t padding_len_, + const int64_t* stride, + int64_t stride_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + const char* attr, + const double** scalars, + int64_t scalars_len_, + const char** algorithm, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu_mkldnn__convolution_transpose_pointwise( + AtenTensorHandle X, + AtenTensorHandle W, + AtenTensorHandle* B, + const int64_t* padding, + int64_t padding_len_, + const int64_t* output_padding, + int64_t output_padding_len_, + const int64_t* stride, + int64_t stride_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + const char* attr, + const double** scalars, + int64_t scalars_len_, + const char** algorithm, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mkldnn_rnn_layer( + AtenTensorHandle input, + AtenTensorHandle weight0, + AtenTensorHandle weight1, + AtenTensorHandle weight2, + AtenTensorHandle weight3, + AtenTensorHandle hx_, + AtenTensorHandle cx_, + int32_t reverse, + const int64_t* batch_sizes, + int64_t batch_sizes_len_, + int64_t mode, + int64_t hidden_size, + int64_t num_layers, + int32_t has_biases, + int32_t bidirectional, + int32_t batch_first, + int32_t train, + AtenTensorHandle* ret0, + AtenTensorHandle* ret1, + AtenTensorHandle* ret2, + AtenTensorHandle* ret3); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__linear_pointwise( + AtenTensorHandle X, + AtenTensorHandle W, + AtenTensorHandle* B, + const char* attr, + const double** scalars, + int64_t scalars_len_, + const char** algorithm, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__linear_pointwise_binary( + AtenTensorHandle X, + AtenTensorHandle other, + AtenTensorHandle W, + AtenTensorHandle* B, + const char* attr, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__qlinear_pointwise_tensor( + AtenTensorHandle X, + AtenTensorHandle act_scale, + AtenTensorHandle act_zero_point, + AtenTensorHandle onednn_weight, + AtenTensorHandle weight_scales, + AtenTensorHandle weight_zero_points, + AtenTensorHandle* B, + double output_scale, + int64_t output_zero_point, + const int32_t* output_dtype, + const char* post_op_name, + const double** post_op_args, + int64_t post_op_args_len_, + const char* post_op_algorithm, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu__qlinear_pointwise_binary_tensor( + AtenTensorHandle X, + AtenTensorHandle act_scale, + AtenTensorHandle act_zero_point, + AtenTensorHandle onednn_weight, + AtenTensorHandle weight_scales, + AtenTensorHandle weight_zero_points, + AtenTensorHandle* other, + AtenTensorHandle* B, + double output_scale, + int64_t output_zero_point, + const int32_t* output_dtype, + double other_scale, + int64_t other_zero_point, + const char* binary_post_op, + double binary_alpha, + const char* unary_post_op, + const double** unary_post_op_args, + int64_t unary_post_op_args_len_, + const char* unary_post_op_algorithm, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__qconv_pointwise_tensor( + AtenTensorHandle X, + AtenTensorHandle act_scale, + AtenTensorHandle act_zero_point, + AtenTensorHandle onednn_weight, + AtenTensorHandle weight_scales, + AtenTensorHandle weight_zero_points, + AtenTensorHandle* B, + const int64_t* stride, + int64_t stride_len_, + const int64_t* padding, + int64_t padding_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + double output_scale, + int64_t output_zero_point, + const int32_t* output_dtype, + const char* attr, + const double** post_op_args, + int64_t post_op_args_len_, + const char** algorithm, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu__qconv2d_pointwise_binary_tensor( + AtenTensorHandle X, + AtenTensorHandle act_scale, + AtenTensorHandle act_zero_point, + AtenTensorHandle onednn_weight, + AtenTensorHandle weight_scales, + AtenTensorHandle weight_zero_points, + AtenTensorHandle accum, + AtenTensorHandle* B, + const int64_t* stride_args, + int64_t stride_len_, + const int64_t* padding_args, + int64_t padding_len_, + const int64_t* dilation_args, + int64_t dilation_len_, + int64_t groups, + double output_scale, + int64_t output_zero_point, + const int32_t* output_dtype, + double accum_scale, + int64_t accum_zero_point, + const char* binary_attr, + double* alpha, + const char** unary_attr, + const double** unary_scalars, + int64_t unary_scalars_len_, + const char** unary_algorithm, + AtenTensorHandle* ret0); + +#if AT_MKL_ENABLED() + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__mkl_linear( + AtenTensorHandle X, + AtenTensorHandle W, + AtenTensorHandle origin_W, + AtenTensorHandle* B, + int64_t prepack_batch_size, + AtenTensorHandle* ret0); + +#endif // AT_MKL_ENABLED + +#endif // AT_MKLDNN_ENABLED() + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor( + AtenTensorHandle X, + AtenTensorHandle w, + AtenTensorHandle qGroupSize, + AtenTensorHandle qScaleAndZeros, + AtenTensorHandle* ret0); + +#ifdef __cplusplus +} // extern "C" +#endif +#endif // AOTI_TORCH_SHIM_CPU diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/c/shim_mps.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/c/shim_mps.h new file mode 100644 index 0000000000000000000000000000000000000000..bd86885de13ca82eff78f7c78f4fba8b824d7e99 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/c/shim_mps.h @@ -0,0 +1,39 @@ +#ifndef AOTI_TORCH_SHIM_MPS +#define AOTI_TORCH_SHIM_MPS + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct AOTIMetalKernelFunctionOpaque; +using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*; + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg_tensor( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + AtenTensorHandle tensor); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg_int( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + int64_t val); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_mps_malloc(void** buffer, size_t num_bytes); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_free(void* ptr); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_memcpy( + void* buffer, + size_t constant_offset, + size_t bytes_read, + size_t data_size, + uint8_t* constants_start); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // AOTI_TORCH_SHIM_MPS diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/c/shim_xpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/c/shim_xpu.h new file mode 100644 index 0000000000000000000000000000000000000000..408c99ca655f65077f22c7b53f7f710b38a0b0d8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/c/shim_xpu.h @@ -0,0 +1,116 @@ +#ifndef AOTI_TORCH_SHIM_XPU +#define AOTI_TORCH_SHIM_XPU + +#include +#include + +#ifdef USE_XPU +#ifdef __cplusplus +extern "C" { +#endif + +struct XPUGuardOpaque; +using XPUGuardHandle = XPUGuardOpaque*; + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_xpu_guard( + int32_t device_index, + XPUGuardHandle* ret_guard // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_delete_xpu_guard(XPUGuardHandle guard); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_xpu_guard_set_index(XPUGuardHandle guard, int32_t device_index); + +struct XPUStreamGuardOpaque; +using XPUStreamGuardHandle = XPUStreamGuardOpaque*; + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_xpu_stream_guard( + void* stream, + int32_t device_index, + XPUStreamGuardHandle* ret_guard // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_delete_xpu_stream_guard(XPUStreamGuardHandle guard); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_current_xpu_stream(int32_t device_index, void** ret_stream); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_current_xpu_device(int32_t* device_index); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_set_current_xpu_device(const int32_t& device_index); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_sycl_queue(void** ret); + +#if AT_MKLDNN_ENABLED() + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_xpu_mkldnn__convolution_pointwise_binary( + AtenTensorHandle X, + AtenTensorHandle other, + AtenTensorHandle W, + AtenTensorHandle* B, + const int64_t* padding, + int64_t padding_len_, + const int64_t* stride, + int64_t stride_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + const char* binary_attr, + double* alpha, + const char** unary_attr, + const double** unary_scalars, + int64_t unary_scalars_len_, + const char** unary_algorithm, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_mkldnn__convolution_pointwise( + AtenTensorHandle X, + AtenTensorHandle W, + AtenTensorHandle* B, + const int64_t* padding, + int64_t padding_len_, + const int64_t* stride, + int64_t stride_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + const char* attr, + const double** scalars, + int64_t scalars_len_, + const char** algorithm, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_xpu_mkldnn__convolution_pointwise_binary_( + AtenTensorHandle other, + AtenTensorHandle X, + AtenTensorHandle W, + AtenTensorHandle* B, + const int64_t* padding, + int64_t padding_len_, + const int64_t* stride, + int64_t stride_len_, + const int64_t* dilation, + int64_t dilation_len_, + int64_t groups, + const char* binary_attr, + double* alpha, + const char** unary_attr, + const double** unary_scalars, + int64_t unary_scalars_len_, + const char** unary_algorithm, + AtenTensorHandle* ret0); + +#endif // AT_MKLDNN_ENABLED() +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // USE_XPU +#endif // AOTI_TORCH_SHIM_XPU diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..2aa09cb802ecdeff946893871386e02bef9ee700 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -0,0 +1,159 @@ + + +// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND. +// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_backward(AtenTensorHandle grad, AtenTensorHandle x1, AtenTensorHandle x2, double p, AtenTensorHandle cdist, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_matmul_4bit(AtenTensorHandle inp, AtenTensorHandle packed_weights, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_pack_4bit_weight(AtenTensorHandle weights, AtenTensorHandle scales_zeros, AtenTensorHandle* bias, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag_forward_only(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag_per_sample_weights_backward(AtenTensorHandle grad, AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, AtenTensorHandle offset2bag, int64_t mode, int64_t padding_idx, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_abs(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_add_Tensor(AtenTensorHandle self, AtenTensorHandle other, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_addmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_addmv(AtenTensorHandle self, AtenTensorHandle mat, AtenTensorHandle vec, double beta, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_angle(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_baddbmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cholesky_inverse(AtenTensorHandle self, int32_t upper, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cholesky_solve(AtenTensorHandle self, AtenTensorHandle input2, int32_t upper, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_convolution(AtenTensorHandle input, AtenTensorHandle weight, AtenTensorHandle* bias, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_convolution_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle weight, const int64_t** bias_sizes, int64_t bias_sizes_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cummax(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fill__Scalar(AtenTensorHandle self, double value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_fractional_max_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_gcd(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_geqrf(AtenTensorHandle self, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_grid_sampler_2d_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle grid, int64_t interpolation_mode, int64_t padding_mode, int32_t align_corners, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_hann_window(int64_t window_length, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_histogram_bin_ct(AtenTensorHandle self, int64_t bins, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_index_Tensor(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_index_put(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle values, int32_t accumulate, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_index_reduce(AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle source, const char* reduce, int32_t include_self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_kthvalue(AtenTensorHandle self, int64_t k, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_logcumsumexp(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_lu_unpack(AtenTensorHandle LU_data, AtenTensorHandle LU_pivots, int32_t unpack_data, int32_t unpack_pivots, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_masked_scatter(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle source, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_masked_scatter_backward(AtenTensorHandle grad_output, AtenTensorHandle mask, const int64_t* sizes, int64_t sizes_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_masked_select(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_pool2d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_pool2d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_pool3d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_pool3d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_unpool2d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_max_unpool3d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_median(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mode(AtenTensorHandle self, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_permute(AtenTensorHandle self, const int64_t* dims, int64_t dims_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_reshape(AtenTensorHandle self, const int64_t* shape, int64_t shape_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_searchsorted_Scalar(AtenTensorHandle sorted_sequence, double self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_searchsorted_Tensor(AtenTensorHandle sorted_sequence, AtenTensorHandle self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_set__source_Tensor(AtenTensorHandle self, AtenTensorHandle source); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_squeeze_dim(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_upsample_bicubic2d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_h, double* scales_w, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_upsample_linear1d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_upsample_trilinear3d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_d, double* scales_h, double* scales_w, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_view_dtype(AtenTensorHandle self, int32_t dtype, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_view_as_complex(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_view_as_real(AtenTensorHandle self, AtenTensorHandle* ret0); + +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..e0607f984b3d0f7cfb2cc1c4ac8a7135f7956c63 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -0,0 +1,165 @@ + + +// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND. +// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__adaptive_avg_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__adaptive_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__adaptive_avg_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__adaptive_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__cdist_backward(AtenTensorHandle grad, AtenTensorHandle x1, AtenTensorHandle x2, double p, AtenTensorHandle cdist, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__cudnn_rnn(AtenTensorHandle input, const AtenTensorHandle* weight, int64_t weight_len_, int64_t weight_stride0, AtenTensorHandle* weight_buf, AtenTensorHandle hx, AtenTensorHandle* cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, int32_t batch_first, double dropout, int32_t train, int32_t bidirectional, const int64_t* batch_sizes, int64_t batch_sizes_len_, AtenTensorHandle* dropout_state, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* bias, AtenTensorHandle out, AtenTensorHandle* cu_seqlens_q, AtenTensorHandle* cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, AtenTensorHandle logsumexp, double dropout_p, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, int64_t custom_mask_type, int32_t bias_requires_grad, double* scale, int64_t* num_splits_key, int64_t* window_size, int32_t shared_storage_dqdkdv, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficient_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* bias, AtenTensorHandle* cu_seqlens_q, AtenTensorHandle* cu_seqlens_k, int64_t* max_seqlen_q, int64_t* max_seqlen_k, double dropout_p, int64_t custom_mask_type, int32_t compute_log_sumexp, double* scale, AtenTensorHandle* seqlen_k, int64_t* window_size, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag_forward_only(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag_per_sample_weights_backward(AtenTensorHandle grad, AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, AtenTensorHandle offset2bag, int64_t mode, int64_t padding_idx, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle rng_state, AtenTensorHandle unused, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* seqused_k, AtenTensorHandle* alibi_slopes, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_cudnn_attention(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, int32_t compute_log_sumexp, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_cudnn_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, AtenTensorHandle attn_bias, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_attention(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, int32_t compute_log_sumexp, double dropout_p, int32_t is_causal, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double dropout_p, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, int32_t is_causal, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__thnn_fused_lstm_cell(AtenTensorHandle input_gates, AtenTensorHandle hidden_gates, AtenTensorHandle cx, AtenTensorHandle* input_bias, AtenTensorHandle* hidden_bias, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_abs(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_add_Tensor(AtenTensorHandle self, AtenTensorHandle other, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_addmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_addmv(AtenTensorHandle self, AtenTensorHandle mat, AtenTensorHandle vec, double beta, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_angle(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_baddbmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cholesky_inverse(AtenTensorHandle self, int32_t upper, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cholesky_solve(AtenTensorHandle self, AtenTensorHandle input2, int32_t upper, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_convolution(AtenTensorHandle input, AtenTensorHandle weight, AtenTensorHandle* bias, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_convolution_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle weight, const int64_t** bias_sizes, int64_t bias_sizes_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cummax(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fill__Scalar(AtenTensorHandle self, double value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle random_samples, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_fractional_max_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_gcd(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_geqrf(AtenTensorHandle self, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_grid_sampler_2d_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle grid, int64_t interpolation_mode, int64_t padding_mode, int32_t align_corners, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_hann_window(int64_t window_length, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_index_Tensor(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_index_put(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle values, int32_t accumulate, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_index_reduce(AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle source, const char* reduce, int32_t include_self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_kthvalue(AtenTensorHandle self, int64_t k, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_logcumsumexp(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_lu_unpack(AtenTensorHandle LU_data, AtenTensorHandle LU_pivots, int32_t unpack_data, int32_t unpack_pivots, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_masked_scatter(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle source, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_masked_scatter_backward(AtenTensorHandle grad_output, AtenTensorHandle mask, const int64_t* sizes, int64_t sizes_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_masked_select(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_max_pool2d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_max_pool2d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_max_pool3d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_max_pool3d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_max_unpool2d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_max_unpool3d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_median(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mode(AtenTensorHandle self, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_ormqr(AtenTensorHandle self, AtenTensorHandle input2, AtenTensorHandle input3, int32_t left, int32_t transpose, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_permute(AtenTensorHandle self, const int64_t* dims, int64_t dims_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_reshape(AtenTensorHandle self, const int64_t* shape, int64_t shape_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_searchsorted_Scalar(AtenTensorHandle sorted_sequence, double self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_searchsorted_Tensor(AtenTensorHandle sorted_sequence, AtenTensorHandle self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_segment_reduce(AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* indices, AtenTensorHandle* offsets, int64_t axis, int32_t unsafe, double* initial, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_set__source_Tensor(AtenTensorHandle self, AtenTensorHandle source); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_squeeze_dim(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_upsample_bicubic2d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_h, double* scales_w, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_upsample_linear1d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_upsample_trilinear3d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_d, double* scales_h, double* scales_w, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_view_dtype(AtenTensorHandle self, int32_t dtype, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_view_as_complex(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_view_as_real(AtenTensorHandle self, AtenTensorHandle* ret0); + +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h new file mode 100644 index 0000000000000000000000000000000000000000..7002a7153d83393df0bcfb2fadc90f309dcb69fe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h @@ -0,0 +1,119 @@ + + +// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND. +// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__adaptive_avg_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__adaptive_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_abs(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_add_Tensor(AtenTensorHandle self, AtenTensorHandle other, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_addmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_addmv(AtenTensorHandle self, AtenTensorHandle mat, AtenTensorHandle vec, double beta, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_angle(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_baddbmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cholesky_solve(AtenTensorHandle self, AtenTensorHandle input2, int32_t upper, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_convolution(AtenTensorHandle input, AtenTensorHandle weight, AtenTensorHandle* bias, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_convolution_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle weight, const int64_t** bias_sizes, int64_t bias_sizes_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cummax(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cumprod(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cumsum(AtenTensorHandle self, int64_t dim, int32_t* dtype, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_fill__Scalar(AtenTensorHandle self, double value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_hann_window(int64_t window_length, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histogram_bin_ct(AtenTensorHandle self, int64_t bins, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_index_Tensor(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_index_put(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle values, int32_t accumulate, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_kthvalue(AtenTensorHandle self, int64_t k, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_logcumsumexp(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_lu_unpack(AtenTensorHandle LU_data, AtenTensorHandle LU_pivots, int32_t unpack_data, int32_t unpack_pivots, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_masked_scatter(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle source, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_masked_scatter_backward(AtenTensorHandle grad_output, AtenTensorHandle mask, const int64_t* sizes, int64_t sizes_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_masked_select(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool2d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool2d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_median(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_permute(AtenTensorHandle self, const int64_t* dims, int64_t dims_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pow_Scalar(double self, AtenTensorHandle exponent, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pow_Tensor_Scalar(AtenTensorHandle self, double exponent, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_reshape(AtenTensorHandle self, const int64_t* shape, int64_t shape_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_searchsorted_Scalar(AtenTensorHandle sorted_sequence, double self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_searchsorted_Tensor(AtenTensorHandle sorted_sequence, AtenTensorHandle self, int32_t out_int32, int32_t right, const char** side, AtenTensorHandle* sorter, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set__source_Tensor(AtenTensorHandle self, AtenTensorHandle source); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_sort_stable(AtenTensorHandle self, int32_t* stable, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_squeeze_dim(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_topk(AtenTensorHandle self, int64_t k, int64_t dim, int32_t largest, int32_t sorted, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_triangular_solve(AtenTensorHandle self, AtenTensorHandle A, int32_t upper, int32_t transpose, int32_t unitriangular, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_upsample_bicubic2d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_h, double* scales_w, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_upsample_linear1d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_upsample_trilinear3d_backward(AtenTensorHandle grad_output, const int64_t* output_size, int64_t output_size_len_, const int64_t* input_size, int64_t input_size_len_, int32_t align_corners, double* scales_d, double* scales_h, double* scales_w, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_view_dtype(AtenTensorHandle self, int32_t dtype, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_view_as_complex(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_view_as_real(AtenTensorHandle self, AtenTensorHandle* ret0); + +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h new file mode 100644 index 0000000000000000000000000000000000000000..243bfb5fc87aafb0daa7fc6bfe6547579fbea292 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h @@ -0,0 +1,67 @@ + + +// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND. +// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int4pack_mm_with_scales_and_zeros(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScale, AtenTensorHandle qZeros, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_abs(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addmv(AtenTensorHandle self, AtenTensorHandle mat, AtenTensorHandle vec, double beta, double alpha, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_baddbmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_cholesky_solve(AtenTensorHandle self, AtenTensorHandle input2, int32_t upper, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_convolution(AtenTensorHandle input, AtenTensorHandle weight, AtenTensorHandle* bias, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_convolution_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle weight, const int64_t** bias_sizes, int64_t bias_sizes_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t transposed, const int64_t* output_padding, int64_t output_padding_len_, int64_t groups, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_cummax(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_cummin(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_hann_window(int64_t window_length, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_index_put(AtenTensorHandle self, const AtenTensorHandle** indices, int64_t indices_len_, AtenTensorHandle values, int32_t accumulate, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_kthvalue(AtenTensorHandle self, int64_t k, int64_t dim, int32_t keepdim, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_logcumsumexp(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_masked_scatter(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle source, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_masked_scatter_backward(AtenTensorHandle grad_output, AtenTensorHandle mask, const int64_t* sizes, int64_t sizes_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_permute(AtenTensorHandle self, const int64_t* dims, int64_t dims_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_reshape(AtenTensorHandle self, const int64_t* shape, int64_t shape_len_, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_slice_Tensor(AtenTensorHandle self, int64_t dim, int64_t* start, int64_t* end, int64_t step, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_soft_margin_loss_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle target, int64_t reduction, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_sort(AtenTensorHandle self, int64_t dim, int32_t descending, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_squeeze_dim(AtenTensorHandle self, int64_t dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_uniform(AtenTensorHandle self, double from, double to, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_view_dtype(AtenTensorHandle self, int32_t dtype, AtenTensorHandle* ret0); + +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/mkldnn_tensor.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/mkldnn_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..adf0927ca0a5eb5473cd7499e642a0b42afaa412 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/mkldnn_tensor.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +namespace torch::aot_inductor { + +void* data_ptr_from_mkldnn(at::Tensor* mkldnn_tensor); + +at::Tensor mkldnn_tensor_from_data_ptr( + void* data_ptr, + at::IntArrayRef dims, + at::ScalarType dtype, + at::Device device, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size); + +} // namespace torch::aot_inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..b0f5b3083cf23272e385dc216654ef6be416e59d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h @@ -0,0 +1,153 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include // @manual +#include +#include + +namespace torch::aot_inductor { + +inline std::ostream& operator<<(std::ostream& os, DynamicArgType arg_type) { + os << static_cast(arg_type); + return os; +} + +struct OSSDynamicArg { + OSSDynamicArg( + int arg_index, + DynamicArgType arg_type, + int length, + std::optional> list_item_types = std::nullopt) + : arg_index(arg_index), + arg_type(arg_type), + length(length), + list_item_types(std::move(list_item_types)) {} + int arg_index; + DynamicArgType arg_type; + int length; + std::optional> + list_item_types; // only used for parsing list of optional tensors +}; + +struct OSSTorchBindArg { + OSSTorchBindArg(int arg_index, std::string arg_name) + : arg_index(arg_index), arg_name(std::move(arg_name)) {} + int arg_index; + // arg_name is used to find the corresponding IValue in customObjs_ + std::string arg_name; +}; + +struct OSSOpKernel { + explicit OSSOpKernel(std::string target) : target_(std::move(target)) {} + // Explicitly declare copy and move constructors + OSSOpKernel(const OSSOpKernel&) = default; + OSSOpKernel(OSSOpKernel&&) = default; + // Explicitly declare copy and move assignment operators + OSSOpKernel& operator=(const OSSOpKernel&) = default; + OSSOpKernel& operator=(OSSOpKernel&&) = default; + + std::string target_; + std::vector dynamic_args_; + std::vector torchbind_args_; + std::vector outputs_; + std::vector stack_; + + int num_output_tensors() const { + int num_output_tensors = 0; + for (const auto& output : outputs_) { + if (isTensorType(output.arg_type)) { + num_output_tensors += output.length; + } + } + return num_output_tensors; + } + + int num_output_ints() const { + int num_output_ints = 0; + for (const auto& output : outputs_) { + if (output.arg_type == DynamicArgType::IntType) { + num_output_ints += output.length; + } + } + return num_output_ints; + } + + virtual void run(std::vector& stack) = 0; + virtual c10::FunctionSchema schema() const = 0; + virtual ~OSSOpKernel() = default; +}; + +struct OSSOpKernelOperator : public OSSOpKernel { + OSSOpKernelOperator(std::string target, c10::OperatorHandle op_handle) + : OSSOpKernel(std::move(target)), op_handle_(std::move(op_handle)) {} + + c10::OperatorHandle op_handle_; + void run(std::vector& stack) override { + op_handle_.callBoxed(stack); + } + + c10::FunctionSchema schema() const override { + return op_handle_.schema(); + } +}; + +struct OSSCallTorchBindKernel : public OSSOpKernel { + OSSCallTorchBindKernel(std::string target, torch::jit::Function* method) + : OSSOpKernel(std::move(target)), method_(method) {} + torch::jit::Function* method_; + void run(std::vector& stack) override { + method_->run(stack); + } + + c10::FunctionSchema schema() const override { + return method_->getSchema(); + } +}; + +class OSSProxyExecutor : public ProxyExecutor { + public: + explicit OSSProxyExecutor( + const std::string& json_path, + bool is_cpu, + std::optional> custom_objs = + std::nullopt); + + void call_function( + int extern_node_index, + int num_ints, + int64_t* flatten_int_args, + int num_tensors, + AtenTensorHandle* flatten_tensor_args) override; + + private: + void prefill_stack_with_static_arguments( + size_t index, + const at::TypePtr& schema_arg_type, + const nlohmann::json& serialized_arg, + OSSOpKernel* op_kernel, + const std::string& torchbind_arg_name); + + void get_input_info_from_serialized( + const std::vector& schema_args, + const nlohmann::json& serialized_node, + OSSOpKernel& op_kernel); + + void get_output_info_from_serialized( + const std::vector& schema_returns, + const nlohmann::json& serialized_node, + OSSOpKernel& op_kernel); + + std::unique_ptr get_call_torch_bind_kernel( + const nlohmann::json& serialized_node); + + std::vector> op_kernels_; + std::unique_ptr device_; + std::unordered_map custom_objs_; +}; + +} // namespace torch::aot_inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/proxy_executor.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/proxy_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..708dc52a760cf046e260819e1dd84a4267639528 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/proxy_executor.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include + +namespace torch::aot_inductor { + +enum class DynamicArgType : int { + TensorType = 0, + ListTensorType = 1, + ListOptionalTensorType = 2, + IntType = 3, + ListIntType = 4, + NoneType = 5, +}; + +inline bool isTensorType(DynamicArgType arg_type) { + return arg_type == DynamicArgType::TensorType || + arg_type == DynamicArgType::ListTensorType || + arg_type == DynamicArgType::ListOptionalTensorType; +} + +class ProxyExecutor { + public: + ProxyExecutor() = default; + virtual ~ProxyExecutor() = default; + + virtual void call_function( + int extern_node_index, + int num_ints, + int64_t* flatten_int_args, + int num_tensors, + AtenTensorHandle* flatten_tensor_args) = 0; +}; + +} // namespace torch::aot_inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/tensor_converter.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/tensor_converter.h new file mode 100644 index 0000000000000000000000000000000000000000..384207e41147cadd4a7ed57bb4a633bb0cb09d2f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/tensor_converter.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include + +namespace torch::aot_inductor { + +// Functions declared here are not meant to be called from the AOTInductor +// generated model.so + +// unsafe_alloc_new_handles_from_tensors is used for allocating new aten +// tensor objects and return them as a vector of AtenTensorHandle (raw +// pointers), and those pointers will be stolen by model.so. +TORCH_API std::vector unsafe_alloc_new_handles_from_tensors( + const std::vector& tensors); + +// alloc_tensors_by_stealing_from_handles is used for creating a vector of aten +// tensors by stealing from an array of handles. Only the handles are stolen, +// and the array itself is borrowed. +// +// WARNING: Can NOT be called in model.so +TORCH_API std::vector alloc_tensors_by_stealing_from_handles( + AtenTensorHandle* handles, + size_t length); + +} // namespace torch::aot_inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/utils.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..4f19fd670d0fc2a306514585c7c8151eabd665d9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/aoti_torch/utils.h @@ -0,0 +1,225 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(...) \ + try { \ + __VA_ARGS__ \ + } catch (const std::exception& e) { \ + LOG(ERROR) << "Exception in aoti_torch: " << e.what(); \ + return AOTI_TORCH_FAILURE; \ + } catch (...) { \ + LOG(ERROR) << "Exception in aoti_torch: UNKNOWN"; \ + return AOTI_TORCH_FAILURE; \ + } \ + return AOTI_TORCH_SUCCESS; + +namespace torch::aot_inductor { + +inline at::Tensor* tensor_handle_to_tensor_pointer(AtenTensorHandle handle) { + return reinterpret_cast(handle); +} + +inline AtenTensorHandle tensor_pointer_to_tensor_handle(at::Tensor* tensor) { + return reinterpret_cast(tensor); +} + +inline at::Tensor resolve_tensor_dispatch_flags(AtenTensorHandle handle) { + at::Tensor* tensor{tensor_handle_to_tensor_pointer(handle)}; + if (tensor->is_conj() || tensor->is_neg()) { + // If the conjugation or negation dispatch flags are set, runtime dispatch + // handles them by cloning the tensor before passing them to the native ATen + // function. Since the C-shim calls the native function directly, we have + // to handle the flags ourselves, or results will be silently incorrect. + return tensor->clone(); + } + return *tensor; +} + +inline std::optional resolve_tensor_dispatch_flags( + const AtenTensorHandle* handle) { + return handle ? std::make_optional(resolve_tensor_dispatch_flags(*handle)) + : std::nullopt; +} + +inline std::vector resolve_tensor_list_dispatch_flags( + const AtenTensorHandle* handle, + int64_t len) { + std::vector ret{}; + ret.reserve(len); + for (int64_t i{0}; i < len; ++i) { + ret.emplace_back(resolve_tensor_dispatch_flags(handle[i])); + } + return ret; +} + +inline std::vector> resolve_tensor_list_dispatch_flags( + const AtenTensorHandle** handle, + int64_t len) { + std::vector> ret{}; + ret.reserve(len); + for (int64_t i{0}; i < len; ++i) { + ret.emplace_back(resolve_tensor_dispatch_flags(handle[i])); + } + return ret; +} + +inline at::Generator* generator_handle_to_generator_pointer( + AtenGeneratorHandle handle) { + return reinterpret_cast(handle); +} + +inline AtenGeneratorHandle generator_pointer_to_generator_handle( + at::Generator* generator) { + return reinterpret_cast(generator); +} + +inline AtenTensorHandle new_tensor_handle(at::Tensor&& tensor) { + at::Tensor* new_tensor = new at::Tensor(std::move(tensor)); + return tensor_pointer_to_tensor_handle(new_tensor); +} + +inline void assert_inf_and_nan( + const std::string& tensor_name, + at::Tensor& check_tensor) { + auto isnan_tensor = check_tensor.isnan(); + if (isnan_tensor.any().item()) { + throw std::runtime_error("At least one NaN in " + tensor_name); + } + auto isinf_tensor = check_tensor.isinf(); + if (isinf_tensor.any().item()) { + throw std::runtime_error("At least one INF in " + tensor_name); + } +} + +// utility functions to convert a pointer to an optional value +template +inline std::optional pointer_to_optional(T* ptr) { + return ptr ? std::make_optional(*ptr) : std::nullopt; +} + +template >> +inline std::optional pointer_to_optional(U* ptr) { + return ptr ? std::make_optional(T(*ptr)) : std::nullopt; +} + +template <> +inline std::optional pointer_to_optional(AtenTensorHandle* ptr) { + return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) + : std::nullopt; +} + +template <> +inline std::optional pointer_to_optional( + const AtenTensorHandle* ptr) { + return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) + : std::nullopt; +} + +template <> +inline std::optional pointer_to_optional( + AtenGeneratorHandle* ptr) { + return ptr ? std::make_optional(*generator_handle_to_generator_pointer(*ptr)) + : std::nullopt; +} + +inline std::optional pointer_to_optional_device( + int32_t* device_type, + int32_t device_index) { + return device_type ? std::make_optional(c10::Device( + static_cast(*device_type), + static_cast(device_index))) + : std::nullopt; +} + +// utility functions to convert a pointer to a list +template +struct is_optional : std::false_type {}; +template +struct is_optional> : std::true_type {}; + +template +inline c10::ArrayRef pointer_to_list(T* ptr, int64_t len) { + return c10::ArrayRef(ptr, len); +} + +template < + class T, + class U, + typename = std::enable_if_t>, + typename = std::enable_if_t::value>> +inline std::vector pointer_to_list(U* ptr, int64_t len) { + // std::vector will be implicitly converted to c10::ArrayRef at the call + // site + std::vector result; + result.reserve(len); + for (int64_t i = 0; i < len; i++) { + result.emplace_back(T(ptr[i])); + } + return result; +} + +template ::value>> +inline std::vector pointer_to_list(U** ptr, int64_t len) { + // Here U** denotes a list of optional arguments + // std::vector will be implicitly converted to c10::ArrayRef at the call + // site + std::vector result; + result.reserve(len); + for (int64_t i = 0; i < len; i++) { + result.emplace_back(pointer_to_optional(ptr[i])); + } + return result; +} + +template <> +inline std::vector pointer_to_list( + const AtenTensorHandle* ptr, + int64_t len) { + std::vector result; + result.reserve(len); + for (int64_t i = 0; i < len; i++) { + result.emplace_back(*tensor_handle_to_tensor_pointer(ptr[i])); + } + return result; +} + +template <> +inline std::vector> pointer_to_list( + const AtenTensorHandle** ptr, + int64_t len) { + std::vector> result; + result.reserve(len); + for (int64_t i = 0; i < len; i++) { + result.emplace_back(pointer_to_optional(ptr[i])); + } + return result; +} + +template +inline std::array pointer_to_list(const int32_t* ptr) { + std::array result; + std::copy(ptr, ptr + N, result.begin()); + return result; +} + +// Utility function to convert a pointer to an optional list of values +template +inline std::optional> pointer_to_optional_list( + U** ptr, + int64_t len) { + return ptr + ? std::make_optional>(pointer_to_list(*ptr, len)) + : std::nullopt; +} + +} // namespace torch::aot_inductor diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/array_ref.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/array_ref.h new file mode 100644 index 0000000000000000000000000000000000000000..de9a53d7df5fdd4d5c84bb1c3770f55cc72deb6f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/array_ref.h @@ -0,0 +1,7 @@ +#pragma once + +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/common.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/common.h new file mode 100644 index 0000000000000000000000000000000000000000..9d9ae16462cc19552b54a08151089d931f775915 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/common.h @@ -0,0 +1,76 @@ +#pragma once + +#include +#include +#include +#include + +#include +#define PYBIND11_SIMPLE_GIL_MANAGEMENT +#include + +// Include some often-used cpp_wrapper headers, for precompiling. +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; // NOLINT(misc-unused-alias-decls) + +class RAIIPyObject { + public: + RAIIPyObject() = default; + // steals a reference to a PyObject + RAIIPyObject(PyObject* obj) : obj_{obj} {} + RAIIPyObject(const RAIIPyObject& other) : obj_{other.obj_} { + Py_XINCREF(obj_); + } + RAIIPyObject(RAIIPyObject&& other) noexcept { + // refcount doesn't change, and obj_ is currently nullptr + std::swap(obj_, other.obj_); + } + ~RAIIPyObject() { + Py_XDECREF(obj_); + } + RAIIPyObject& operator=(const RAIIPyObject& other) { + if (this != &other) { + Py_XDECREF(obj_); + obj_ = other.obj_; + Py_XINCREF(obj_); + } + return *this; + } + RAIIPyObject& operator=(RAIIPyObject&& other) noexcept { + // refcount to the current object decreases, but refcount to other.obj_ is + // the same + Py_XDECREF(obj_); + obj_ = std::exchange(other.obj_, nullptr); + return *this; + } + operator bool() const noexcept { + return obj_; + } + operator PyObject*() { + return obj_; + } + PyObject* get() { + return obj_; + } + + private: + PyObject* obj_{nullptr}; +}; + +#include +#include +using namespace torch::aot_inductor; + +#include +#include + +// Round up to the nearest multiple of 64 +[[maybe_unused]] inline int64_t align(int64_t nbytes) { + return (nbytes + 64 - 1) & -64; +} diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/cpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..76c2afd91606d637ac1d08b020bff4dc0b5bcf90 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/cpu.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/cuda.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..782a2b677276a79e2f900266e3b0ab0b335b23a3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/cuda.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/cpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..c203906bb3f572ddb5da80fb4356c2e4aae84b3c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/cpu.h @@ -0,0 +1,3 @@ +#pragma once + +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..29eaadda4f1725727346c460650c7710391238e6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/mps.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/mps.h new file mode 100644 index 0000000000000000000000000000000000000000..4a25af7ac07afeae80f1a5fa989233aaf6dee855 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/mps.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/xpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/xpu.h new file mode 100644 index 0000000000000000000000000000000000000000..32bce0f4e749bacc759ff96fe417145a4c7f7972 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/xpu.h @@ -0,0 +1,5 @@ +#pragma once + +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/mps.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/mps.h new file mode 100644 index 0000000000000000000000000000000000000000..a5847517bebd931e2cbd448ed09e42d30ed39e6d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/mps.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/xpu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/xpu.h new file mode 100644 index 0000000000000000000000000000000000000000..e26dea0f3b6e2d6d978ee2b88a335f0728261035 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/inductor/cpp_wrapper/xpu.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/compilation_unit.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/compilation_unit.h new file mode 100644 index 0000000000000000000000000000000000000000..a07ff6e4ad9f4afeb0b60e942c7b9e485b3b6376 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/compilation_unit.h @@ -0,0 +1,351 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace torch::jit { + +struct Def; +struct Property; +struct ClassDef; +struct SugaredValue; +struct Resolver; + +using ResolverPtr = std::shared_ptr; +struct Self { + virtual ~Self() = default; + virtual std::shared_ptr makeSugared(Value* v) const = 0; + virtual ClassTypePtr getClassType() const = 0; +}; + +// A CompilationUnit is a list of named Functions +// with helper methods to iterate the list or invoke the function. +// Classes have a CompilationUnit holding the class methods, +// and Modules have a CompilationUnit holding the Functions that +// are used to implement their Methods + +struct TORCH_API CompilationUnit { + enum class FunctionType { Method, Hook, PreHook }; + // constructor that takes a set of functions to compile using the native + // resolver + explicit CompilationUnit(const std::string& source); + CompilationUnit() = default; + + CompilationUnit& operator=(CompilationUnit&&) = default; + CompilationUnit(CompilationUnit&&) = default; + CompilationUnit& operator=(const CompilationUnit&) = delete; + CompilationUnit(const CompilationUnit&) = delete; + + Function* find_function(const c10::QualifiedName& name) const { + auto it = dict_.find(name); + if (it == dict_.end()) { + return nullptr; + } + return functions_[it->second].get(); + } + + Function& get_function(const c10::QualifiedName& name) const { + if (auto r = find_function(name)) { + return *r; + } + TORCH_CHECK(false, "attempted to get undefined function ", name.name()); + } + + void set_optimized(bool o) { + TORCH_WARN( + "CompilationUnit::set_optimized() is deprecated and has no effect. " + "Please use setGraphExecutorOptimize()"); + } + + bool is_optimized() const { + TORCH_WARN( + "CompilationUnit::is_optimized() is deprecated and always returns true. " + "Please use getGraphExecutorOptimize()"); + return true; + } + + // for historic reasons, these are defined in ir_emitter.cpp + // Returns the list of Functions just defined. + std::vector define( + const std::optional& prefix, + const std::vector& properties, + const std::vector& propResolvers, + const std::vector& definitions, + const std::vector& + defResolvers, /* determines how we handle free + variables in each definition*/ + // if non-null, the first argument to each def, is bound to this value + const Self* self, + // see [name mangling] + bool shouldMangle = false, + std::optional operator_set_version = std::nullopt); + + void define_hooks( + const std::optional& prefix, + const std::vector& hookDefs, + const std::vector& hookResolvers, + const std::vector& preHookDefs, + const std::vector& preHookResolvers, + const Self* self, + bool shouldMangle = false); + + // same as above but parse the definitions from source + // Returns the list of Functions just defined. + std::vector define( + // prefix namespace to put all the defined functions into + const std::optional& prefix, + const std::string& source, + const ResolverPtr& resolver, + const Self* self); + + void define_interface( + const c10::QualifiedName& qualifiedName, + const ClassDef& classDef, + ResolverPtr rcb, + bool is_module = false); + + Function* create_function( + c10::QualifiedName name, + std::shared_ptr graph, + bool shouldMangle = false) { + if (shouldMangle) { + name = mangle(name); + } + auto fn = std::make_unique( + std::move(name), std::move(graph), nullptr); + auto ret = fn.get(); + register_function(std::move(fn)); + return ret; + } + + std::vector get_functions() const { + return fmap(functions_, [](const std::unique_ptr& fn) { + return fn.get(); + }); + } + + /// Run a method from this compilation. + /// + /// For example: + /// @code + /// IValue output = module->run("relu_script", a, b); + /// @endcode + /// + /// To get a compile a module from a source string, see torch::jit::compile + /// + /// @param method_name The name of the method to run + /// @param args Arguments to be passed to the method + /// @return An IValue containing the return value (or values if it is a tuple) + /// from the method + template + IValue run_method(const c10::QualifiedName& method_name, Types&&... args) { + return get_function(method_name)({IValue(std::forward(args))...}); + } + + void drop_all_functions() { + dict_.clear(); + functions_.clear(); + } + + /** + * Register a class as being owned by this compilation unit. + */ + void register_type(c10::NamedTypePtr namedType) { + // TODO: class types cannot be redefined because we have no way right now + // of invalidating their methods. NamedTuples are fine though, since they + // don't have methods. + TORCH_CHECK( + 0 == classDict_.count(*namedType->name()), + "class '", + namedType->name()->qualifiedName(), + "' already defined."); + classes_.push_back(std::move(namedType)); + classDict_[*classes_.back()->name()] = classes_.size() - 1; + } + + c10::ClassTypePtr get_class(const c10::QualifiedName& name) const { + auto type = get_type(name); + if (!type) { + return nullptr; + } + return type->cast(); + } + + c10::InterfaceTypePtr get_interface(const c10::QualifiedName& name) const { + auto type = get_type(name); + if (!type) { + return nullptr; + } + return type->cast(); + } + + c10::TupleTypePtr get_named_tuple(const c10::QualifiedName& name) const { + for (const auto& cls : classes_) { + if (cls->name()->qualifiedName() == name.qualifiedName()) { + return cls->expect(); + } + } + return nullptr; + } + + c10::NamedTypePtr get_type(const c10::QualifiedName& name) const { + auto it = classDict_.find(name); + if (it == classDict_.end()) { + return nullptr; + } + return classes_[it->second]; + } + + // For testing: clear all Python-defined classes to ensure that unit tests + // have isolation. + void _clear_python_cu() { + // Delete all the associated class methods + for (const auto& type : classes_) { + if (auto cls = type->cast()) { + for (auto method : cls->methods()) { + // Tombstone the method in the compilation unit. + // Don't erase because the dict_ + auto it = dict_.find(method->qualname()); + if (it != dict_.end()) { + functions_[it->second] = nullptr; + // Erase in our big lookup table + dict_.erase(it); + } + } + // Classes can have multiple pointers to the same hook, + // need to make sure to not delete it twice + std::unordered_set hooks_to_delete; + for (const auto& hook : cls->getForwardHooks()) { + hooks_to_delete.insert(hook); + } + for (const auto& pre_hook : cls->getForwardPreHooks()) { + hooks_to_delete.insert(pre_hook); + } + for (const auto& hook : hooks_to_delete) { + // Tombstone the hook in the compilation unit. + auto it = dict_.find(hook->qualname()); + if (it != dict_.end()) { + functions_[it->second] = nullptr; + // Erase in our big lookup table + dict_.erase(it); + } + } + } + } + classes_.clear(); + classDict_.clear(); + } + + // [Internal Only] Remove method. + // Note Used for freezing. + void unsafeRemoveMethod(const c10::QualifiedName& method_name) { + auto it = dict_.find(method_name); + TORCH_CHECK( + it != dict_.end(), + "method '", + method_name.qualifiedName(), + "' does not exist."); + functions_[it->second] = nullptr; + dict_.erase(it); + } + + // [name mangling] All code objects must have a unique qualified name in a + // CompilationUnit. In Python, sometimes functions won't have unique qualified + // name (for example, nested functions). So we mangle Python functions to + // ensure that they are uniquely named. + // + // We also use mangling to distinguish different Module instances. Since each + // Module is a singleton class instance, different instances of the same + // Python Module will have different types but the same qualified name. + c10::QualifiedName mangle(const c10::QualifiedName& name) const { + auto mangled = name; + while (get_type(mangled) || find_function(mangled)) { + mangled = mangler_.mangle(mangled); + } + return mangled; + } + + private: + std::unique_ptr define( + const std::optional& prefix, + const Def& def, + const ResolverPtr& resolver, + const Self* self, + const std::unordered_map& function_table, + bool shouldMangle = false, + FunctionType type = FunctionType::Method, + std::optional version = std::nullopt) const; + + // Define a property on \p self. + struct PropertyPair; + PropertyPair define_property( + const std::optional& prefix, + const Property& prop, + const ResolverPtr& resolver, + const Self* self, + const std::unordered_map& function_table, + bool shouldMangle = false) const; + + Function& register_function(std::unique_ptr fn) { + TORCH_CHECK( + 0 == dict_.count(fn->qualname().qualifiedName()), + "method '", + fn->qualname().qualifiedName(), + "' already defined."); + functions_.emplace_back(std::move(fn)); + dict_[functions_.back()->qualname()] = functions_.size() - 1; + return *functions_.back(); + } + std::vector> functions_; + // for fast lookup + std::unordered_map dict_; + std::unordered_map classDict_; + + // [class ownership] Right now there are two relationships between classes + // and compilation units: + // 1. Classes have compilation units internally that hold their methods. + // 2. On load, the TypePtrs of any imported classes are owned by the main + // module's compilation unit. + std::vector classes_; + + mutable NameMangler mangler_; +}; + +// An owning pointer to a Function. Just a pair of a raw Function ptr and it's +// owning CU. We need this because pybind requires a ref-counted way to refer to +// Functions. +struct StrongFunctionPtr { + StrongFunctionPtr(std::shared_ptr cu, Function* function) + : cu_(std::move(cu)), function_(function) { + TORCH_INTERNAL_ASSERT(cu_); + TORCH_INTERNAL_ASSERT(function_); + } + std::shared_ptr cu_; + Function* function_; +}; + +namespace script { +// We once had a `script::` namespace that was deleted. This is for backcompat +// of the public API; new code should not use this type alias. +using CompilationUnit = ::torch::jit::CompilationUnit; +} // namespace script +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/function_impl.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/function_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..f5d86039ec0aa612c2e588ae146fed1cc53e2f9a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/function_impl.h @@ -0,0 +1,180 @@ +#pragma once + +#include +#include +#include + +namespace torch::jit { + +struct TORCH_API GraphFunction : public Function { + GraphFunction( + c10::QualifiedName name, + std::shared_ptr graph, + std::function function_creator, + std::optional executor_execution_mode = + std::nullopt) + : name_(std::move(name)), + graph_(std::move(graph)), + executor_execution_mode_(executor_execution_mode), + function_creator_(std::move(function_creator)) {} + + bool isGraphFunction() const override { + return true; + } + + void run(Stack& stack) override; + + std::function function_creator() const { + return function_creator_; + } + + c10::intrusive_ptr runAsync( + Stack& stack, + TaskLauncher taskLauncher = at::launch) override; + + std::shared_ptr graph() const { + return graph_; + } + + std::shared_ptr optimized_graph() const; + + const c10::QualifiedName& qualname() const override { + return name_; + } + + // private/unstable api. sets the initial execution mode + // will not affect executor if there is an existing executor + // created for this function + void _set_initial_executor_execution_mode(ExecutorExecutionMode mode) { + executor_execution_mode_ = mode; + } + // private/unstable api. sets flag of whether or not to ignore amp. + // will not affect executor if there is an existing executor + // created for this function + void _set_ignore_amp(bool ignore_amp) { + force_no_amp_ = ignore_amp; + } + + // if this isn't yet defined, run its method_creator function + void ensure_defined() override; + + size_t num_inputs() const override { + return graph()->inputs().size(); + } + + Function& setSchema(FunctionSchema schema) override { + schema_ = std::make_unique(std::move(schema)); + return *this; + } + + const FunctionSchema& getSchema() const override; + + GraphExecutorState getDebugState() { + return get_executor().getDebugState(); + } + + bool is_optimized() const { + TORCH_WARN( + "GraphFunction::is_optimized() is deprecated and always returns true. " + "Please use getGraphExecutorOptimize()"); + return true; + } + + void check_single_output() { + TORCH_CHECK( + graph()->outputs().size() == 1, + "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs"); + } + + GraphExecutor& get_executor() { + ensure_defined(); + std::lock_guard lock(compile_mutex); + auto& executor = executors_[currentSpecialization()]; + if (executor) { + return *executor; + } + check_single_output(); + const std::string& name = name_.name(); + std::shared_ptr opt_graph = optimized_graph(); + if (!executor_execution_mode_) { + executor = GraphExecutor(opt_graph, name); + } else { + executor = GraphExecutor(opt_graph, name, *executor_execution_mode_); + } + return *executor; + } + + using Function::call; + bool call( + Stack& stack, + std::optional bailOut, + c10::function_ref f) override { + f(get_executor().getPlanFor(stack, bailOut).code); + return true; + } + + void clear_optimized_graphs() { + optimized_graphs_.fill(nullptr); + } + + private: + enum SpecializationKey { + AutocastOff, + CpuAutocastOn, + GpuAutocastOn, + CpuGpuAutocastOn, + + // This provides the number of specializations + // (Must be last entry) + TotalCount + }; + + SpecializationKey currentSpecialization() const; + + private: + c10::QualifiedName name_; + // The original, non-optimized graph + std::shared_ptr graph_; // for debugging and for inlining + + // allows users to specify Simple/Profiling Executor for function + // TODO: add more executors + mutable std::optional executor_execution_mode_; + + // if invoked on a graph that has already traced through amp + // don't invoke amp pass + mutable bool force_no_amp_ = false; + // Optimized graph, computed lazily. Used for inlining. + mutable std::array, SpecializationKey::TotalCount> + optimized_graphs_; + + // GraphFunctions are invokable from multiple threads, so this lock needs to + // be held when we're initializing graph executor for the first time or + // computing the optimized graph. We're using reentrant mutex so that we don't + // need to worry about causing a deadlock by calling one method from another + // (e.g. optimized_graph() from get_executor()). + mutable std::recursive_mutex compile_mutex; + + // executor_[0] - autocast off + // executor_[1] - autocast cpu on + // executor_[2] - autocast gpu on + // executor_[3] - autocast cpu & gpu on + std::array, SpecializationKey::TotalCount> + executors_; + + // an optional function that actually creates the method when + // ensure_defined() is called. This is used by the compiler so + // that it can construct methods out of order + std::function function_creator_; + + // if absent, then we generate a default schema based on the graph + // mutable because getSchema caches the default schema if one is requested + // before a call to setSchema + mutable std::unique_ptr schema_; +}; + +// Short hands for dynamic_cast. +TORCH_API GraphFunction* tryToGraphFunction(Function&) noexcept; +TORCH_API GraphFunction& toGraphFunction(Function&); +TORCH_API const GraphFunction& toGraphFunction(const Function&); +} // namespace torch::jit +C10_DECLARE_bool(torch_jit_do_not_store_optimized_graph); diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/method.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/method.h new file mode 100644 index 0000000000000000000000000000000000000000..28675e5bd059f5e876e1b55c94b2c0a705aca28c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/method.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch::jit { + +using ObjectPtr = c10::intrusive_ptr; + +// A method in a module, e.g. f in: +// +// class M(ScriptModule): +// @script_method +// def f(self, x): +// ... +// Note: because Method/Module are exposed to python these +// classes use python method naming conventions +struct TORCH_API Method : public torch::IMethod { + Method(ObjectPtr owner, Function* function); + + // the module that contains this method. + Module owner() const; + // the raw objectptr that owns this method, for when the method is owned by a + // torchbind object. + ObjectPtr raw_owner() const; + void run(Stack& stack); + void run(Stack&& stack) { + run(stack); + } + + c10::IValue operator()( + std::vector stack, + const Kwargs& kwargs = Kwargs()) const override; + + // Run method async. Invocation on this function would invokes a JIT + // interpreter that executes ops inline, one by one, on caller's thread. A + // model can utilize async op, i.e. `fork`, to launch an asynchronous task + // which will be launched on provided `taskLauncher`. + c10::intrusive_ptr run_async( + std::vector stack, + const Kwargs& kwargs = Kwargs(), + TaskLauncher taskLauncher = at::launch); + + std::shared_ptr graph() const { + return toGraphFunction(*function_).graph(); + } + + const std::string& name() const override { + return function_->name(); + } + + size_t num_inputs() const { + return function_->num_inputs(); + } + + GraphExecutor& get_executor() { + return toGraphFunction(*function_).get_executor(); + } + + Function& function() const { + return *function_; + } + + private: + void setArgumentNames(std::vector&) const override; + + // Methods are uniqued onwed by a single module. This raw pointer allows + // looking up the module. + ObjectPtr owner_; + + // Underlying unbound function + Function* function_; +}; + +namespace script { +// We once had a `script::` namespace that was deleted. This is for backcompat +// of the public API; new code should not use this type alias. +using Method = ::torch::jit::Method; +} // namespace script + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/module.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/module.h new file mode 100644 index 0000000000000000000000000000000000000000..8e9be1de48a5fd0ae0cc1bdccdc848bb1af81cc2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/module.h @@ -0,0 +1,685 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// This file contains classes which assist in desugaring Python style +// modules and their methods into flattened graphs which don't have any +// function calls. + +namespace torch::jit { + +using ::c10::Argument; +using ::c10::FunctionSchema; +using ::c10::QualifiedName; +// Map which stores filename to content. +using ExtraFilesMap = std::unordered_map; + +using ModulePtr = c10::intrusive_ptr; + +struct Module; + +template +struct slot_list_impl; + +template +struct Named { + std::string name; + T value; +}; + +using NameModule = Named; +using NameValue = Named; +using NameTensor = Named; + +namespace detail { +struct TORCH_API ModulePolicy; +struct TORCH_API ParameterPolicy; +struct TORCH_API AttributePolicy; +struct TORCH_API BufferPolicy; +template +struct NamedPolicy; +} // namespace detail + +using module_list = slot_list_impl; +using named_module_list = + slot_list_impl>; + +using parameter_list = slot_list_impl; +using named_parameter_list = + slot_list_impl>; + +using attribute_list = slot_list_impl; +using named_attribute_list = + slot_list_impl>; + +using buffer_list = slot_list_impl; +using named_buffer_list = + slot_list_impl>; + +using ModuleLookup = std::function&)>; + +struct TORCH_API Module : public Object { + explicit Module(c10::QualifiedName class_name); + Module(std::shared_ptr cu, const c10::ClassTypePtr& type); + Module() = default; + Module(const Module&) = default; + Module& operator=(const Module&) = default; + Module(Module&&) noexcept = default; + Module& operator=(Module&&) noexcept = default; + Module( + c10::QualifiedName, + std::shared_ptr cu, + bool shouldMangle = false); + Module(ModulePtr module_value) : Object(std::move(module_value)) {} + ~Module() = default; + + void set_optimized(bool o) { + TORCH_WARN( + "Module::set_optimized() is deprecated and has no effect. " + "Please use setGraphExecutorOptimize()"); + } + + bool is_optimized() const { + TORCH_WARN( + "Module::is_optimized() is deprecated and always returns true. " + "Please use getGraphExecutorOptimize()"); + return true; + } + + IValue forward(std::vector inputs, const Kwargs& kwargs = Kwargs()) { + return get_method("forward")(std::move(inputs), kwargs); + } + + // In script modules, buffers are Tensors attribute that are _not_ registered + // as parameters. This is different than in nn.Module where there is a special + // register_buffer method. With this simplification, we only need to track + // whether a slot is a parameter to be able to classify it. + void register_buffer(const std::string& name, at::Tensor v) { + bool is_param = false; + bool is_buffer = true; + std::lock_guard lock(*register_mutex_); + type()->addOrCheckAttribute(name, TensorType::get(), is_param, is_buffer); + _ivalue()->setAttr(name, std::move(v)); + } + + void register_parameter( + const std::string& name, + at::Tensor v, + bool is_buffer) { + std::lock_guard lock(*register_mutex_); + type()->addOrCheckAttribute(name, TensorType::get(), !is_buffer, is_buffer); + _ivalue()->setAttr(name, std::move(v)); + } + + void register_attribute( + const std::string& name, + const TypePtr& t, + IValue v, + bool is_param = false, + bool is_buffer = false) { + type()->addOrCheckAttribute(name, t, is_param, is_buffer); + _ivalue()->setAttr(name, std::move(v)); + } + + void register_module(const std::string& name, const Module& module) { + type()->addOrCheckAttribute(name, module.type()); + _ivalue()->setAttr(name, module._ivalue()); + } + + void apply(const std::function& fn); + + buffer_list buffers(bool recurse = true) const; + named_buffer_list named_buffers(bool recurse = true) const; + + module_list children() const; // direct modules + named_module_list named_children() const; + module_list modules() const; // all modules, including this one, recursively + named_module_list named_modules() const; + + // all tensors involved in gradient optimization + parameter_list parameters(bool recurse = true) const; + named_parameter_list named_parameters(bool recurse = true) const; + + // all members of the object, similar to iterating over dir(obj) in python + attribute_list attributes(bool recurse = true) const; + named_attribute_list named_attributes(bool recurse = true) const; + + void dump( + bool print_method_bodies, + bool print_attr_values, + bool print_param_values) const; + + std::string dump_to_str( + bool print_method_bodies, + bool print_attr_values, + bool print_param_values) const; + + /// Enables "training" mode. + void train(bool on = true); + /// Calls train(false) to enable "eval" mode. + /// Do not override this method, override `train()` instead. + void eval() { + train(/*on=*/false); + } + /// True if the module is in training mode. + bool is_training() const { + return attr("training", true).toBool(); + } + + /// Recursively casts all parameters to the given `dtype` and `device`. + /// + /// If `non_blocking` is true and the source is in pinned memory and + /// destination is on the GPU or vice versa, the copy is performed + /// asynchronously with respect to the host. Otherwise, the argument has no + /// effect. + void to(at::Device device, at::ScalarType dtype, bool non_blocking = false); + + /// Recursively casts all parameters to the given dtype. + /// + /// If `non_blocking` is true and the source is in pinned memory and + /// destination is on the GPU or vice versa, the copy is performed + /// asynchronously with respect to the host. Otherwise, the argument has no + /// effect. + void to(at::ScalarType dtype, bool non_blocking = false); + + /// Recursively moves all parameters to the given device. + /// + /// If `non_blocking` is true and the source is in pinned memory and + /// destination is on the GPU or vice versa, the copy is performed + /// asynchronously with respect to the host. Otherwise, the argument has no + /// effect. + void to(at::Device device, bool non_blocking = false); + + void save( + std::ostream& out, + const ExtraFilesMap& extra_files = ExtraFilesMap()) const; + + void save( + const std::string& filename, + const ExtraFilesMap& extra_files = ExtraFilesMap()) const; + + void _save_for_mobile( + std::ostream& out, + const ExtraFilesMap& extra_files = ExtraFilesMap(), + bool save_mobile_debug_info = false, + bool use_flatbuffer = false) const; + + void _save_for_mobile( + const std::string& filename, + const ExtraFilesMap& extra_files = ExtraFilesMap(), + bool save_mobile_debug_info = false, + bool use_flatbuffer = false) const; + + Module copy() const; + + Module deepcopy(std::optional device = std::nullopt) const; + + // Clones both the underlying `ClassType` and the module instance(data), this + // function creates a new `ClassType` and returns a new instance that has the + // same data as the current instance but with the new type, shared ClassType + // will be preserved as well + Module clone(bool inplace = false) const; + + // Clones both the underlying `ClassType` and the module instance(data), this + // function creates a new `ClassType` and returns a new instance that has the + // same data as the current instance but with the new type, shared ClassType + // will be preserved as well. Also allows the caller to specify a set of + // method and attribute names to not clone. + Module clone( + bool inplace, + const std::unordered_set& ignored_method, + const std::unordered_set& ignored_attributes) const; + + void clone_method(const Module& orig, const std::string& name); + + IValue operator()(std::vector inputs); + + template + IValue create_class(const c10::QualifiedName& name, Types&&... args) const { + return create_class(name, {IValue(std::forward(args))...}); + } + + IValue create_class(const c10::QualifiedName& name, Stack stack) const; + + inline bool operator==(const Module& y) const noexcept { + return _ivalue() == y._ivalue(); + } + + void set_delete_memory(std::shared_ptr delete_mem) { + mem_to_delete_ = std::move(delete_mem); + } + + // A set of functions to maintain input shapes through torch.jit.save and + // torch.jit.load. It only works on tensors and lists/dicts of tensors + // because tracing is only supported by these types. + void store_traced_inputs( + const std::string& func_name, + std::vector inputs) { + if (inputs.empty()) { + return; + } + auto c10_inputs = c10::impl::GenericList(AnyType::get()); + for (IValue& value : inputs) { + // Not checking whether this is traceable type as that is already checked + // higher up in the stack and changing that would require a larger + // restructuring. + c10_inputs.emplace_back(std::move(value)); + } + traced_inputs_.insert_or_assign(func_name, c10_inputs); + } + + c10::Dict retrieve_traced_inputs() + const { + return traced_inputs_; + } + + private: + Module clone_impl( + std::unordered_map& type_remap, + bool inplace, + IValue::HashIdentityIValueMap memo, + const std::unordered_set& ignored_methods, + const std::unordered_set& ignored_attributes) const; + + void clone_method( + const Module& orig, + const Function& method, + const std::unordered_map& type_remap); + + c10::QualifiedName getNameForMethod(std::string basename) const { + return QualifiedName(*type()->name(), std::move(basename)); + } + + void to_impl( + const std::optional& device, + const std::optional& dtype, + bool non_blocking); + + // Extra handle for the module to delete when itself is deleted + std::shared_ptr mem_to_delete_; + + // Map of function names to the traced inputs that they have been traced with + c10::Dict traced_inputs_; + + // Mutex to keep registring buffer or parameter thread safe. + std::shared_ptr register_mutex_ = std::make_shared(); +}; + +// C++ equivalent api of `torch.jit.freeze`. See documentation there for +// details. +TORCH_API Module freeze( + const Module& module, + const std::optional>& preserved_attrs = + std::nullopt, + bool optimize_numerics = true); + +// C++ equivalent api of `torch.jit.optimize_for_inference`. See documentation +// there for details. +TORCH_API Module optimize_for_inference( + Module& module, + const std::vector& other_methods = {}); + +enum class FusionBehavior { STATIC, DYNAMIC }; + +using FusionStrategy = std::vector>; +// clang-format off +/* +Sets the type and number of specializations that can occur during fusion. + +Usage: provide a list of pairs (type, depth) where type is one of STATIC or DYNAMIC +and depth is an integer. + +Behavior - static vs dynamic: + In STATIC fusion, fused ops are compiled to have fixed input shapes. The shape is determined + based on some initial profiling runs. + In DYNAMIC fusion, fused ops are compiled to have variable input shapes, so that multiple + shapes are possible. + +In both cases, we also recompile on new striding behavior, device, or dtype. + +Behavior - fallback functions & depth: + When an input doesn't match the format required by the specialized compiled op, it will run + a fallback function. Fallback functions are recursively be compiled and specialized based + on the observed tensor shapes. Since compilation can be slow, the "depth" parameter is provided to + limit the number of specializations that can be compiled, before giving up on recompiling and + falling back to a completely un-fused, un-specialized implementation. + +The list of (type, depth) pairs controls the type of specializations and the number of +specializations. For example: [(STATIC, 2), (DYNAMIC, 2)] indicates that the first +two specializations will use static fusions, the following two specializations will use +dynamic fusion, and any inputs that satisfy none of the 4 options will run an +unfused implementation. + +NB: in the future, if more as more fusion backends are added there may be more granular +apis for specific fusers. +*/ +// clang-format on +TORCH_API FusionStrategy getFusionStrategy(); +// returns previous strategy +TORCH_API FusionStrategy setFusionStrategy(FusionStrategy& fusion_strategy); + +namespace detail { + +struct TORCH_API SlotCursor { + Module module_; + int64_t i_; // slot offset, -1 indicates the module itself +}; + +} // namespace detail + +// This iterator allows the (optionally recursive) enumeration of +// the members of a Module. It performs a depth-first pre-order +// traversal of the module. The Policy template parameter determines +// which slots of the object should be included. For instance, +// when iterating parameters, we return the parameter tensors, +// but skip modules, buffers, and other attributes. +// See ModulePolicy for comments about Policy object's API. +template +struct slot_iterator_impl { + using SlotCursor = detail::SlotCursor; + using value_type = typename Policy::value_type; + slot_iterator_impl( + Module root, + bool recurse, // if true, do a depth-first search, otherwise, just look at + // slots of root + bool return_module) // if true include root itself as the first thing + // visited (used in modules()) + : cursors_({SlotCursor{std::move(root), return_module ? -1 : 0}}), + recurse_(recurse) { + // advance iterator to first valid element (or the end, if empty) + while_not_valid_next(); + } + // empty cursors_, represents end of iteration + slot_iterator_impl() : recurse_(false) {} + value_type operator*() const { + return Policy::create(cursors_, cur()); + } + value_type operator->() const { + return **this; + } + slot_iterator_impl& operator++() { + next_valid(); + return *this; + } + slot_iterator_impl operator++(int) { + // this is really expensive, should we delete it so people don't use it + // instead of prefix? + slot_iterator_impl old = *this; + ++(*this); + return old; + } + + private: + // return_module() is a corner case where instead of returning a submodule + // of root, we are returning root itself, because we are iterating modules(), + // which contains the root module itself. + // It is represented with a single SlotCursor whose index is -1. + bool return_module() const { + return top().i_ == -1; + } + const SlotCursor& top() const { + return cursors_.back(); + } + SlotCursor& top() { + return cursors_.back(); + } + IValue cur() const { + return return_module() ? top().module_._ivalue() + : top().module_._ivalue()->getSlot(top().i_); + } + + // advance to the next slot in a depth first pre-order traversal of the + // modules slots. This function does not guarantee the next slot is a + // valid element of the iteration. That is done by valid(). + // invariant: !cursors_.empty() + void next() { + // we just returned the module itself, advance i_ to 0 so we are now + // at the first slot of the module. + if (return_module()) { + ++top().i_; + return; + } + // the last traversal action advanced beyond the number of slots in the + // module so continue the iteration in the parent. + if (top().i_ >= int64_t(top().module_._ivalue()->type()->numAttributes())) { + cursors_.pop_back(); + if (!cursors_.empty()) { + ++top().i_; + } + return; + } + // if the current thing is a module, we have to scan it for recursive + // traversals. We do this by adding a new SlotCursor to track the traversal. + if (recurse_ && + top().module_._ivalue()->type()->getAttribute(top().i_)->is_module()) { + cursors_.emplace_back(SlotCursor{cur().toModule(), 0}); + return; + } + // common case: advance to the next slot. + ++top().i_; + } + // is the current position of the iterator a valid one? + // otherwise, we have to continue advancing. + bool valid() const { + return top().i_ < + int64_t(top().module_._ivalue()->type()->numAttributes()) && + Policy::valid( + top().module_._ivalue()->type(), + top().i_, + top().module_._ivalue()->getSlot(top().i_)); + } + void while_not_valid_next() { + // advance iteration until we are either at the end (cursors_.empty()) + // or in a valid state. return_module() is a special case, + // and is always considered valid, regardless of Policy, because it is + // it is only true when we are iterating modules. + while (!cursors_.empty() && !return_module() && !valid()) { + next(); + } + } + void next_valid() { + // avoid crashing if this is empty + if (cursors_.empty()) { + return; + } + // advance to next element, which is maybe not valid + next(); + while_not_valid_next(); + } + + std::vector cursors_; + bool recurse_; + + friend inline bool operator!=( + const slot_iterator_impl& a, + const slot_iterator_impl& b) { + // we are finished iteration when we have no more iteration SlotCursors. + // end is always an empty iterator with no cursors. + return (a.cursors_.empty() != b.cursors_.empty()); + } +}; + +// This type represents lists of parameters, attributes, and +// submodules contained in the module. It is abstract because +// they are not stored directly in std::vectors but inside the +// module's IValue object itself. +template +struct slot_list_impl { + using iterator = slot_iterator_impl; + using const_iterator = slot_iterator_impl; + using value_type = typename iterator::value_type; + slot_iterator_impl begin() const { + return slot_iterator_impl(module_, recurse_, return_module_); + } + slot_iterator_impl end() const { + return slot_iterator_impl(); + } + size_t size() const { + if (!size_) { + size_ = size_t(0); + for ([[maybe_unused]] const value_type& _ : *(this)) { + ++*size_; + } + } + return *size_; + } + + slot_list_impl(Module module, bool recurse, bool return_module) + : module_(std::move(module)), + recurse_(recurse), + return_module_(return_module), + size_(std::nullopt) { + if (!recurse && !return_module && Policy::all_slots) { + size_ = module_.num_slots(); + } + } + + private: + Module module_; + bool recurse_; + bool return_module_; + // size of this list, cached on first request + // when we need to filter the slot list + mutable std::optional size_; + friend struct Module; +}; + +namespace detail { + +// slot_iterator_impl always iterate over all the slots in a module, +// the Policy template argument determines slots should be returned and their +// types +struct TORCH_API ModulePolicy { + // the type of the value being returned + using value_type = Module; + + // the logic for creating the type being returned, given the raw IValue + // of that object. + static value_type create( + const std::vector& cursors, + IValue v) { + return Module(std::move(v).toObject()); + } + // is slot i in typ something that this iterator should return, otherwise, + // we skip it. + static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) { + return typ->getAttribute(i)->is_module(); + } + // are we going to return everything? If so, we can optimize the calculate + // of the size of the list. + static constexpr bool all_slots = false; +}; + +struct TORCH_API ParameterPolicy { + using value_type = at::Tensor; + static value_type create( + const std::vector& cursors, + IValue v) { + return std::move(v).toTensor(); + } + static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) { + return typ->is_parameter(i) && v.isTensor(); + } + static constexpr bool all_slots = false; +}; + +struct TORCH_API BufferPolicy { + using value_type = at::Tensor; + static value_type create( + const std::vector& cursors, + IValue v) { + return std::move(v).toTensor(); + } + static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) { + return typ->getAttribute(i)->isSubtypeOf(*TensorType::get()) && + typ->is_buffer(i); + } + static constexpr bool all_slots = false; +}; + +struct TORCH_API AttributePolicy { + using value_type = IValue; + static value_type create( + const std::vector& cursors, + IValue v) { + return v; + } + static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) { + return true; + } + static constexpr bool all_slots = true; +}; + +// take a Policy object, and make a version of it that returns the slot. +// along with the fully qualified name of that slot. This is used for the named_ +// variants like named_parameters(). +template +struct NamedPolicy { + using value_type = Named; + static value_type create( + const std::vector& cursors, + IValue v) { + std::string name; + if (cursors.size() == 1) { + name = (cursors.back().i_ == -1) ? "" : nameFragment(cursors.back()); + } else { + std::ostringstream ss; + for (const auto i : c10::irange(cursors.size())) { + if (i > 0) { + ss << "."; + } + ss << nameFragment(cursors[i]); + } + name = ss.str(); + } + return value_type{std::move(name), Policy::create(cursors, std::move(v))}; + } + static bool valid(const ClassTypePtr& t, size_t i, const IValue& v) { + return Policy::valid(t, i, v); + } + static constexpr bool all_slots = Policy::all_slots; + + private: + static std::string nameFragment(const detail::SlotCursor& f) { + return f.module_.type()->getAttributeName(f.i_); + } +}; + +} // namespace detail + +TORCH_API bool& getInlineEverythingMode(); + +namespace script { +// We once had a `script::` namespace that was deleted. This is for backcompat +// of the public API; new code should not use this type alias. +using Module = ::torch::jit::Module; +using ExtraFilesMap = ::torch::jit::ExtraFilesMap; +} // namespace script + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/object.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/object.h new file mode 100644 index 0000000000000000000000000000000000000000..8f0d11d718747412bd88ec87c237e1bddc26c1bf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/api/object.h @@ -0,0 +1,200 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace torch::jit { + +struct Resolver; +using ResolverPtr = std::shared_ptr; + +using ObjectPtr = c10::intrusive_ptr; + +// Throw this in C++ land if `attr` fails. This will be converted to a Python +// AttributeError by the Python binding code +class ObjectAttributeError : public std::runtime_error { + public: + ObjectAttributeError(const std::string& what) : std::runtime_error(what) {} +}; + +struct TORCH_API Object { + Object() = default; + Object(const Object&) = default; + Object& operator=(const Object&) = default; + Object(Object&&) noexcept = default; + Object& operator=(Object&&) noexcept = default; + Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {} + Object(std::shared_ptr cu, const c10::ClassTypePtr& type); + Object( + c10::QualifiedName, + std::shared_ptr cu, + bool shouldMangle = false); + + ObjectPtr _ivalue() const { + TORCH_INTERNAL_ASSERT(_ivalue_); + return _ivalue_; + } + + c10::ClassTypePtr type() const { + return _ivalue()->type(); + } + + struct Property { + std::string name; + Method getter_func; + std::optional setter_func; + }; + + void setattr(const std::string& name, c10::IValue v) { + if (_ivalue()->type()->hasConstant(name)) { + TORCH_CHECK( + false, + "Can't set constant '", + name, + "' which has value:", + _ivalue()->type()->getConstant(name)); + } else if (auto slot = _ivalue()->type()->findAttributeSlot(name)) { + const c10::TypePtr& expected = _ivalue()->type()->getAttribute(*slot); + TORCH_CHECK( + v.type()->isSubtypeOf(*expected), + "Expected a value of type '", + expected->repr_str(), + "' for field '", + name, + "', but found '", + v.type()->repr_str(), + "'"); + _ivalue()->setSlot(*slot, std::move(v)); + } else { + TORCH_CHECK(false, "Module has no attribute '", name, "'"); + } + } + + c10::IValue attr(const std::string& name) const { + if (auto r = _ivalue()->type()->findAttributeSlot(name)) { + return _ivalue()->getSlot(*r); + } + if (auto r = _ivalue()->type()->findConstantSlot(name)) { + return _ivalue()->type()->getConstant(*r); + } + std::stringstream err; + err << _ivalue()->type()->repr_str() << " does not have a field with name '" + << name.c_str() << "'"; + throw ObjectAttributeError(err.str()); + } + + c10::IValue attr(const std::string& name, c10::IValue or_else) const { + if (auto r = _ivalue()->type()->findAttributeSlot(name)) { + return _ivalue()->getSlot(*r); + } + if (auto r = _ivalue()->type()->findConstantSlot(name)) { + return _ivalue()->type()->getConstant(*r); + } + return or_else; + } + + bool hasattr(const std::string& name) const { + return _ivalue()->type()->hasAttribute(name) || + _ivalue()->type()->hasConstant(name); + } + + // each object owns its methods. The reference returned here + // is guaranteed to stay valid until this module has been destroyed + Method get_method(const std::string& name) const { + if (auto method = find_method(name)) { + return *method; + } + TORCH_CHECK(false, "Method '", name, "' is not defined."); + } + + const std::vector get_methods() const { + return c10::fmap(type()->methods(), [&](Function* func) { + return Method(_ivalue(), func); + }); + } + + bool has_property(const std::string& name) const { + for (const auto& prop : type()->properties()) { + if (prop.name == name) { + return true; + } + } + return false; + } + + const Property get_property(const std::string& name) const { + for (const auto& prop : type()->properties()) { + if (prop.name == name) { + std::optional setter = std::nullopt; + if (prop.setter) { + setter = Method(_ivalue(), prop.setter); + } + return Property{ + prop.name, Method(_ivalue(), prop.getter), std::move(setter)}; + } + } + TORCH_CHECK(false, "Property '", name, "' is not defined."); + } + + const std::vector get_properties() const { + return c10::fmap(type()->properties(), [&](ClassType::Property prop) { + std::optional setter = std::nullopt; + if (prop.setter) { + setter = Method(_ivalue(), prop.setter); + } + return Property{ + std::move(prop.name), + Method(_ivalue(), prop.getter), + std::move(setter)}; + }); + } + + std::optional find_method(const std::string& basename) const; + + /// Run a method from this module. + /// + /// For example: + /// @code + /// IValue output = module->run("relu_script", a, b); + /// @endcode + /// + /// To get a compile a module from a source string, see torch::jit::compile + /// + /// @param method_name The name of the method to run + /// @param args Arguments to be passed to the method + /// @return An IValue containing the return value (or values if it is a tuple) + /// from the method + template + IValue run_method(const std::string& method_name, Types&&... args) { + return get_method(method_name)({IValue(std::forward(args))...}); + } + + // so that C++ users can easily add methods + void define(const std::string& src, const ResolverPtr& resolver = nullptr); + + size_t num_slots() const { + return _ivalue()->slots().size(); + } + + // shallow copy the object + Object copy() const; + + // Copies all the attributes of the object recursively without creating new + // `ClassType`, including deepcopy of Tensors + Object deepcopy() const; + + private: + // mutable be we lazily initialize in module_object. + mutable ObjectPtr _ivalue_; +}; + +namespace script { +// We once had a `script::` namespace that was deleted. This is for backcompat +// of the public API; new code should not use this type alias. +using Object = ::torch::jit::Object; +} // namespace script +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend.h new file mode 100644 index 0000000000000000000000000000000000000000..519220d63fc89c5b6a4abeecd18c541e50473e3a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::jit { +namespace { +inline c10::FunctionSchema getIsAvailableSchema() { + c10::Argument self("self", c10::AnyType::get()); + c10::Argument available("available", c10::BoolType::get()); + c10::FunctionSchema preprocessor_schema( + "is_available", + /*overload_name=*/"", + /*arguments=*/{self}, + /*returns=*/{available}); + return preprocessor_schema; +} + +constexpr static auto kBackendsNamespace = "__backends__"; + +inline c10::FunctionSchema getCompileSchema() { + c10::Argument self("self", c10::AnyType::get()); + c10::Argument mod("processed", c10::AnyType::get()); + auto any_dict_ty = + c10::DictType::create(c10::StringType::get(), c10::AnyType::get()); + c10::Argument method_compile_spec("method_compile_spec", any_dict_ty); + c10::Argument handles("handles", any_dict_ty); + + c10::FunctionSchema compile_schema( + "compile", + /*overload_name=*/"", + /*arguments=*/{self, mod, method_compile_spec}, + /*returns=*/{handles}); + return compile_schema; +} + +inline c10::FunctionSchema getExecuteSchema() { + auto any_list_ty = c10::ListType::create(c10::AnyType::get()); + c10::Argument self("self", c10::AnyType::get()); + c10::Argument handle("handle", c10::AnyType::get()); + c10::Argument input("input", any_list_ty); + c10::Argument output("output", any_list_ty); + return c10::FunctionSchema( + "execute", + /*overload_name=*/"", + /*arguments=*/{self, handle, input}, + /*returns=*/{output}); +} + +template +std::function getIsAvailableFunc() { + return [](Stack& stack) { + auto self = pop(stack).toCustomClass(); + auto ret = self->is_available(); + push(stack, ret); + }; +} + +template +std::function getCompileFunc() { + return [](Stack& stack) { + auto method_compile_spec = pop(stack).toGenericDict(); + auto processed = pop(stack); + auto self = pop(stack).toCustomClass(); + auto ret = self->compile(processed, method_compile_spec); + push(stack, ret); + }; +} + +template +std::function getExecuteFunc() { + return [](Stack& stack) { + auto args = pop(stack); + auto handle = pop(stack); + auto self = pop(stack); + auto backend = self.toCustomClass(); + auto res = backend->execute(handle, args.toList()); + push(stack, res); + }; +} +} // namespace + +// Static registration API for backends. +template +class backend { + static_assert( + std::is_base_of_v, + "torch::jit::backend requires T to inherit from PyTorchBackendInterface"); + std::string backend_name_; + + public: + // Registers a new backend with /p name, and the given /p preprocess + // function. + backend(const std::string& name) : backend_name_(name) { + static auto cls = torch::class_(kBackendsNamespace, name) + .def(torch::init<>()) + ._def_unboxed( + "is_available", + getIsAvailableFunc(), + getIsAvailableSchema()) + ._def_unboxed( + "compile", + getCompileFunc(), + getCompileSchema()) + ._def_unboxed( + "execute", + getExecuteFunc(), + getExecuteSchema()); + } +}; + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_debug_handler.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_debug_handler.h new file mode 100644 index 0000000000000000000000000000000000000000..4128832e7a078a456e247544724247c69dd184f2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_debug_handler.h @@ -0,0 +1,138 @@ +#pragma once +#include + +#include +#include +#include + +#include + +namespace torch::jit { + +/* + * BackendDebugHandleManager is responsible for issuing debug handles to + * backends. Debug handles are associated with nodes of a graph. + * BackendDebugHandleManager also maintains a map + * [debug-handle, DebugInfoTuple = {source range, inlined callstack ptr]} that + * will help generate a callstack for exception raised using debug handles. + * Effectively debug handles are something that is given to backend and later + * when an exception occurs in the backend, backend can tell, using debug + * handle, that an exception occurred here. Then the runtime can generate + * callstack correspoding to the exception. + * There are two parts to BackendDebugHandleManager: + * 1. static std::atomic debug_handle + * 2. Map of [debug-handle, DebugInfoTuple] + * + * About 1: + * Why do they have to be unique. The reason is that by ensuring + * uniqueness of debug handles, we remove the burden of another layer of + * mapping where we need to say this set of debug handles were generated for + * this lowered module or this bytecode function. This simplifies the API for + * serialization since debug handles can uniquely identify DebugInfoTuple. + * Thus simplifies the runtime API for throwing exception. Exception throwing + * only needs to know debug_handle and not which module or method threw it. + * There are 2 issues to keep in mind, though,for static std::atomic + * debug_handle: A. Performance implications of using atomic variable. However + * this is only used for compilation so we assume to absorb some of that + * penalty. Plus if there is no contention then we should have less to worry + * about. B. If repeated compilation is part of a long running process then we + * may overflow int64_t. We may detect and fail on this. For now this is not + * done. + * + * Now about 2: + * There are two usecases for [debug-handle, DebugInfoTuple] + * A. During bytecode generation the DebugInfoTuple corresponding to the nodes + * of the inlined graph being serialized, are stored in this object and a + * unique debug handle is returned. This unique debug handle is stored in + * mobile_debug info for pytorch lite models. It will be used for raising + * exceptions as well as profiling. B. During backend lowering, each backend's + * preprocess/compile method can compile method's graph and serialize those + * methods. Once the method is lowered to backend, graph is essentially lost. + * Without access to graph it is hard to generate model level debug info. Thus + * the debug handles provide a way to map nodes of the graph to the model level + * debug info. + * + * During byte-code model serialization, [debug-handle, DebugInfoTuple] is + * serialized. Now we know a. debug handles and b. how to map debug handles to + * model source code. Thus we can either do eager symbolication by converting + * debug handles to corresponding source code at runtime, or do lazy + * symbolicattion offline. + * + * Note that it is not necessary to serialize [debug-handle, DebugInfoTuple] + * corresponding to lowered backend if the lowering process, that is + * preprocess/compile, and execution happens in the same session, then eager + * symbolication can be employed. + * + * Now how does BackendDebugHandleManager capture all of the above? + * By providing two API. + * 1. getNextDebugHandle which given a Node* returns a unique debug handle, + * that will uniquely identify DebugInfoTuple. + * and + * 2. getCallStackPtrMap which returns the map + * [debug-handle, DebugInfoTuple] + * + * 1 provides debug handles to backends and 2 provides runtime a way to map + * debug handles to source level debug info. + * + * So why does debug handle map to DebugInfoTuple = {source range and inlined + * cs}? {debug_handle, source_range_tag, serialized_callstack} Take this + * example: class L(nn.Module): def __init__(self) -> None: + * ... + * def forward(self, x): + * return x * 5 + * class M(nn.Module): + * def __init__(self) -> None: + * ... + * def forward(self, x): + * return x - 2 + * class N(nn.Module): + * def __init__(self) -> None: + * self.m = M() + * def forward(self, x): + * return self.m(x) + 3 + * m = torch.jit.script(N()) + * Once you inline m's forward method, m.forward.graph will look something + * like this + * graph(%self...): + * %x = aten::mul(..) + * %x = aten::sub(x, ..) + * %y = aten::add(x, ..) + * .. + * Inlined callstack ptr for these two nodes will look like: + * aten::mul's inlined CS (callstack): [N.forward, source range] -> [M.forward, + * source range] aten::sub's inlined CS (callstack): [N.forward, source range] + * aten::add's inlined CS: null + * mul node's inlined CS contains only information about the callsites' source + * range The information about mul node's source range ('return x * 5') is not + * available in its inlined CS. It is rather part of node's source range + * instead of inlined CS. Thus to get full stack: [N.forward, source range] -> + * [M.forward, source range] -> [aten::mul's source range] We need to track + * mul's source range and inlined CS both. + */ + +using BackendDebugInfoMapType = + std::unordered_map; + +/* + * This class is used to generate debug info map. + * backend's preprocess will call generate_debug_handles (see + * backend_detail.cpp), which uses debug_handle_manager to generate debug + * handles. When lowering process finishes, calling stopRecording will + * return debug info map from debug_handle_manager + */ +class TORCH_API BackendDebugInfoRecorder { + public: + BackendDebugInfoRecorder() = default; + int64_t getNextDebugHandle(const Node* node); + // Reason this is not done as RAII is that work done in stopRecording + // can throw, and throwing with dtor will call terminate and thus voids any + // exception catching at a higher level. + BackendDebugInfoMapType stopRecording(); + NodeToDebugHandle generate_debug_handles(const std::shared_ptr& graph); + + private: + static std::atomic unique_debug_handle_; + BackendDebugInfoMapType handles_to_inlined_callstack_ptrs_; +}; + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_debug_info.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_debug_info.h new file mode 100644 index 0000000000000000000000000000000000000000..d6740b6c50466abb9a322bdda302ecef2007831f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_debug_info.h @@ -0,0 +1,63 @@ +#pragma once + +#ifndef BUILD_LITE_INTERPRETER +#include +#endif +#include + +namespace torch::jit { + +constexpr static auto kBackendUtilsNamespace = "backendutils"; +constexpr static auto kBackendDebugInfoClass = "BackendDebugInfo"; + +#ifndef BUILD_LITE_INTERPRETER +/* + * Custom class for holding debug information in lowered modules, intended + * purely for keeping this information to be later serialized outside of the + * lowered module itself. + * Its usage pattern is: + * 1. LoweredModule declares an instance of this class in __backend_debug_info + * 2. During serialization, __backend_debug_info is used to obtain the debug + * information. + * 3. The contents of LoweredModule.__backend_debug_info are not serialized + * within the LoweredModule itself. + */ +class TORCH_API PyTorchBackendDebugInfo : public torch::CustomClassHolder { + public: + PyTorchBackendDebugInfo() = default; + + std::optional& getDebugInfoMap() { + return debug_info_map_; + } + + void setDebugInfoMap(BackendDebugInfoMapType&& debug_info_map) { + debug_info_map_ = std::move(debug_info_map); + } + + private: + std::optional debug_info_map_; +}; + +#else + +/* + * Dummy instance exists for the following reason: + * __backend_debug_info is of type BackendDebugInfo which is a torchbind' + * class backed by cpp class PyTorchBackendDebugInfo. + * PyTorchBackendDebugInfo, depends on ir.h., scope.h, source_range etc. + * We dont include this on lite interpreter side. Thus on lite interpreter side + * we cannot have valid definition of PyTorchBackendDebugInfo. However we do not + * need valid instance of __backend_debug_info in lite interpreter anyway as we + * dont serialize this info as part of LowerdModule as mentioned ealrier. + * However since LoweredModule has registered attribute of __backend_debug_info + * we still need to make sure that BackendDebugInfo is registered with + * TorchScript. However in this instance it does not have to be backed by + * PyTorchBackendDebugInfo, so we create a dummy PyTorchBackendDebugInfoDummy + * just for this purpose. + */ +class PyTorchBackendDebugInfoDummy : public torch::CustomClassHolder { + public: + PyTorchBackendDebugInfoDummy() = default; +}; +#endif +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_detail.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_detail.h new file mode 100644 index 0000000000000000000000000000000000000000..e69a93ebb148ec130b1f8d3f2523b53258b93c42 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_detail.h @@ -0,0 +1,39 @@ +#pragma once + +#include + +#include + +#include + +namespace torch::jit { + +using DebugHandleType = int64_t; + +using NodeToDebugHandle = std::unordered_map; + +using BackendDebugHandleGenerator = + std::function&)>; + +namespace detail { + +using BackendPreprocessFunction = std::function&, + const BackendDebugHandleGenerator& generate_debug_handles)>; + +TORCH_API void registerBackendPreprocessFunction( + const std::string& name, + const BackendPreprocessFunction& preprocess); + +bool hasBackendPreprocessFunction(const std::string& name); + +BackendPreprocessFunction getBackendPreprocessFunction(const std::string& name); + +TORCH_API Module codegen_backend_module( + const std::string& backend_name, + const Module& orig_module, + const c10::Dict& method_compile_spec, + const c10::DictTypePtr& any_dict_ty); +} // namespace detail +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_exception.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_exception.h new file mode 100644 index 0000000000000000000000000000000000000000..d964f1bfcf008625da4db52b643c664d10c6f6cf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_exception.h @@ -0,0 +1,56 @@ +#pragma once +#include + +#include + +namespace c10 { +class TORCH_API BackendRuntimeException : public c10::Error { + public: + // Use debug_handle to throw exception + BackendRuntimeException( + SourceLocation loc, + std::string msg, + int64_t debug_handle) + : c10::Error(loc, std::move(msg)) { + debug_handles.push_back(debug_handle); + } + // If rethrowing, can push another debug_handle + // This is useful in couple of scenarios. + // 1. A submodule is lowered and lite interperter has CallMethod + // to lowered module's method. In this case lowered module will throw with + // a handle, plus there will be another debug handle corresponding + // to the CallMethod node in lite interpreter. Both together give complete + // trace. This function allows lite interpreter to rethrow with debug + // handle it has for CallMethod. + // 2. Another scenarios is when lite interperter can make function calls or + // the lowered backend also has function call ability. Thus we have + // multiple function frames. Now we need a stack of handles to symbolicate + // entire stack trace. + void pushDebugHandle(int64_t debug_handle) { + debug_handles.push_back(debug_handle); + } + const std::vector& getDebugHandles() { + return debug_handles; + } + + private: + // Stores stack of debug handles. + std::vector debug_handles; +}; + +} // namespace c10 +#define TORCH_DELEGATED_BACKEND_THROW(cond, msg, debug_handle) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + throw ::c10::BackendRuntimeException( \ + {__func__, __FILE__, static_cast(__LINE__)}, \ + msg, \ + debug_handle); \ + } + +#define TORCH_DELEGATED_BACKEND_RETHROW(e, debug_handle) \ + do { \ + e.pushDebugHandle(debug_handle); \ + throw; \ + } while (false) + +#define DEBUG_HANDLE_UNKNOWN -1 diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_init.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_init.h new file mode 100644 index 0000000000000000000000000000000000000000..7f2aac18bd04f8b71f8706d1256372f5a8f7c5ce --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_init.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include + +namespace torch::jit { +// Initialize Python bindings for JIT to_ functions. +void initJitBackendBindings(PyObject* module); +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_interface.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..331497f929d4c2efd4075374adf4f11d862027f8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_interface.h @@ -0,0 +1,32 @@ +#pragma once + +#include + +namespace torch::jit { + +// Interface for a JIT backend. +class TORCH_API PyTorchBackendInterface : public torch::CustomClassHolder { + public: + PyTorchBackendInterface() noexcept; + ~PyTorchBackendInterface() override; + + // Returns true if the backend is available to process delegation calls. + virtual bool is_available() = 0; + + // Compile the module contained in \p processed using the details provided in + // \p method_compile_spec for each module method that should be compiled for + // the backend. \p method_compile_spec should be of type Dict. + // \returns a dictionary of type Dict that contains a backend + // handle each method that can run on the backend (i.e. each key in \p + // method_compile_spec). + virtual c10::impl::GenericDict compile( + c10::IValue processed, + c10::impl::GenericDict method_compile_spec) = 0; + + // Execute the method specified by \p handle using \p inputs. \returns the + // outputs as a tuple. + virtual c10::impl::GenericList execute( + c10::IValue handle, + c10::impl::GenericList inputs) = 0; +}; +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_preprocess.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_preprocess.h new file mode 100644 index 0000000000000000000000000000000000000000..da4ebd5a93754a58901eff92f2c96e7a87f70cce --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_preprocess.h @@ -0,0 +1,16 @@ +#pragma once + +#include +namespace torch::jit { +class backend_preprocess_register { + std::string backend_name_; + + public: + backend_preprocess_register( + const std::string& name, + const detail::BackendPreprocessFunction& preprocess) + : backend_name_(name) { + detail::registerBackendPreprocessFunction(name, preprocess); + } +}; +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_resolver.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_resolver.h new file mode 100644 index 0000000000000000000000000000000000000000..9dd4483725766400afe341a3c4e311e77247fa90 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/backend_resolver.h @@ -0,0 +1,8 @@ +#pragma once + +#include + +namespace torch::jit { +// Create a Resolver for use in generating LoweredModules for specific backends. +TORCH_API std::shared_ptr loweredModuleResolver(); +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/cpp/context.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/cpp/context.h new file mode 100644 index 0000000000000000000000000000000000000000..a07a8b81fc7d2886a4306ac4b73bafa1203a2fe6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/cpp/context.h @@ -0,0 +1,22 @@ +#ifndef PTM_COREML_Context_h +#define PTM_COREML_Context_h + +#include + +namespace torch::jit::mobile::coreml { + +struct ContextInterface { + virtual ~ContextInterface() = default; + virtual void setModelCacheDirectory(std::string path) = 0; +}; + +class BackendRegistrar { + public: + explicit BackendRegistrar(ContextInterface* ctx); +}; + +void setModelCacheDirectory(std::string path); + +} // namespace torch::jit::mobile::coreml + +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h new file mode 100644 index 0000000000000000000000000000000000000000..0221e84f9ac822d47b1a6d579d6dcc18edb0a9a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h @@ -0,0 +1,22 @@ +#import + +#include + +NS_ASSUME_NONNULL_BEGIN + +@interface PTMCoreMLCompiler : NSObject + ++ (void)setCacheDirectory:(const std::string&)dir; + ++ (NSString*)cacheDirectory; + ++ (BOOL)compileModel:(const std::string&)modelSpecs modelID:(const std::string&)modelID; + ++ (nullable MLModel*)loadModel:(const std::string)modelID + backend:(const std::string)backend + allowLowPrecision:(BOOL)allowLowPrecision + error:(NSError**)error; + +@end + +NS_ASSUME_NONNULL_END diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.h new file mode 100644 index 0000000000000000000000000000000000000000..35cc2ca10a569ec0519da1709d5c8eaddb70aebe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.h @@ -0,0 +1,19 @@ +#import + +#import + +NS_ASSUME_NONNULL_BEGIN + +@interface PTMCoreMLExecutor : NSObject + +@property(atomic, strong) MLModel* model; + +- (instancetype)initWithFeatureNames:(NSArray*)featureNames; + +- (void)setInputs:(c10::impl::GenericList)inputs; + +- (id)forward:(NSError**)error; + +@end + +NS_ASSUME_NONNULL_END diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.h new file mode 100644 index 0000000000000000000000000000000000000000..f0ccd7280b09d7da1b8bb45631d31392a78cd31b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.h @@ -0,0 +1,16 @@ +#import +#import + +NS_ASSUME_NONNULL_BEGIN + +@interface PTMCoreMLFeatureProvider : NSObject + +- (instancetype)initWithFeatureNames:(NSSet*)featureNames; + +- (void)clearInputTensors; + +- (void)setInputTensor:(const at::Tensor&)tensor forFeatureName:(NSString*)name; + +@end + +NS_ASSUME_NONNULL_END diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLModelWrapper.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLModelWrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..5b4e37d4fa5a45eb50d55006c78bbc449f952a0f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLModelWrapper.h @@ -0,0 +1,41 @@ +#include +#include +#include + +namespace torch { +namespace jit { +namespace mobile { +namespace coreml { + +class MLModelWrapper : public CustomClassHolder { + public: + PTMCoreMLExecutor* executor; + std::vector outputs; + + MLModelWrapper() = delete; + + MLModelWrapper(PTMCoreMLExecutor* executor) : executor(executor) { + [executor retain]; + } + + MLModelWrapper(const MLModelWrapper& oldObject) { + executor = oldObject.executor; + outputs = oldObject.outputs; + [executor retain]; + } + + MLModelWrapper(MLModelWrapper&& oldObject) { + executor = oldObject.executor; + outputs = oldObject.outputs; + [executor retain]; + } + + ~MLModelWrapper() { + [executor release]; + } +}; + +} // namespace coreml +} // namespace mobile +} // namespace jit +} // namespace torch diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLTensorSpec.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLTensorSpec.h new file mode 100644 index 0000000000000000000000000000000000000000..5ad77cc3f3f89560dd3367d64db39595bf5c550e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLTensorSpec.h @@ -0,0 +1,26 @@ +#include +#import + +#include + +namespace torch::jit::mobile::coreml { + +struct TensorSpec { + std::string name; + c10::ScalarType dtype = c10::ScalarType::Float; +}; + +static inline c10::ScalarType scalar_type(const std::string& type_string) { + if (type_string == "0") { + return c10::ScalarType::Float; + } else if (type_string == "1") { + return c10::ScalarType::Double; + } else if (type_string == "2") { + return c10::ScalarType::Int; + } else if (type_string == "3") { + return c10::ScalarType::Long; + } + return c10::ScalarType::Undefined; +} + +} // namespace torch::jit::mobile::coreml diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h new file mode 100644 index 0000000000000000000000000000000000000000..30a2ae6b8d80df266045f194c012124f03d211cb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h @@ -0,0 +1,24 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +namespace torch::jit::xnnpack::delegate { + +class XNNCompiler { + public: + // Takes Flatbuffer Serialized XNNPack Model and rebuilds the xnn-subgraph + // returns an executor object that holds the xnn runtime object which we + // can then use to set inputs and run inference using the xnn graph. + static void compileModel( + const void* buffer_pointer, + size_t num_bytes, + XNNExecutor* executor); +}; + +} // namespace torch::jit::xnnpack::delegate diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..118af11d031fca3b52909976657e5029e7d19d6b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h @@ -0,0 +1,68 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include +#include + +namespace torch::jit::xnnpack::delegate { + +class XNNExecutor { + private: + std::unique_ptr runtime_{ + nullptr, + &xnn_delete_runtime}; + std::vector input_ids_; + std::vector output_ids_; + std::vector externals_; + + public: + XNNExecutor() = default; + + template + bool set_inputs(std::vector& inputs, std::vector& outputs) { + externals_.clear(); + + if (inputs.size() != input_ids_.size()) { + return false; + } + + for (int i = 0; i < inputs.size(); i++) { + externals_.emplace_back(xnn_external_value{input_ids_[i], inputs[i]}); + } + + if (outputs.size() != output_ids_.size()) { + return false; + } + + for (int i = 0; i < outputs.size(); i++) { + externals_.emplace_back(xnn_external_value{output_ids_[i], outputs[i]}); + } + + return true; + } + + bool forward() { + xnn_status status = + xnn_setup_runtime(runtime_.get(), externals_.size(), externals_.data()); + + if (status != xnn_status_success) { + return false; + } + + status = xnn_invoke_runtime(runtime_.get()); + + if (status != xnn_status_success) { + return false; + } + + return true; + } + + friend class XNNCompiler; +}; + +} // namespace torch::jit::xnnpack::delegate diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/xnnpack/serialization/serializer.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/xnnpack/serialization/serializer.h new file mode 100644 index 0000000000000000000000000000000000000000..05594983d246b0e42c3597142ca9f482319efcb8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/xnnpack/serialization/serializer.h @@ -0,0 +1,89 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace xnnpack { +namespace delegate { + +using namespace fb_xnnpack; // Specified in the schema + +class XNNSerializer { + public: + // Constructors + // initial buffersize of 1024 which will grow + // automatically, constant buffer and buffer sizes initialized with dummy + // values as 0 index is reserved for non-constant tensors + XNNSerializer() : XNNSerializer(1024) {} + + explicit XNNSerializer(size_t bufferSize) + : _builder(bufferSize), + _nodes(), + _values(), + _constantBuffer({CreateBuffer( + _builder, + {})}), // index 0 is reserved for non-const data + _bufferSizes({0}) {} + + // Serializing Nodes + + // Serialize add node, we are serializing the argument needed to call + // xnn_define_add2. Serializing these values, and at run time we build + // teh graph by re running xnn_define_add2 + void serializeAddNode( + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags); + + // Serializing Values + void serializeTensorValue( + uint32_t xnn_datatype, + size_t num_dims, + std::vector dims, + size_t buffer_data_idx, + uint32_t external_id, + uint32_t flags, + uint32_t id_out); + + // finish and serialize xnngraph returning serialized data + std::string finishAndSerialize( + std::vector input_ids, + std::vector output_ids, + size_t num_extern_ids); + + // decoupled data serialization with tensor values. This way constant tensor + // data can be referenced by multiple intermediate tensors. This call + // serializes the num_bytes of the data_ptr and returns the index it was + // placed in. + size_t serializeData(const uint8_t* data_ptr, size_t num_bytes); + + private: + // xnnpack version we are serializing + const char* _version_sha1 = "ae108ef49aa5623b896fc93d4298c49d1750d9ba"; + + // flatbuffer objects we will create and serialize together to create xnngraph + flatbuffers_fbsource::FlatBufferBuilder _builder; + + // Vector of the serialized xnnpack nodes + std::vector> _nodes; + + // Vector of the serialized xnnpack values + std::vector> _values; + + std::vector> _constantBuffer; + std::vector _bufferSizes; +}; + +} // namespace delegate +} // namespace xnnpack +} // namespace jit +} // namespace torch diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..b018b4adbb889f7c3176f757625364899e66d266 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h @@ -0,0 +1,97 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace xnnpack { +namespace delegate { + +class XNNGraph { + private: + const float output_min = -std::numeric_limits::infinity(); + const float output_max = std::numeric_limits::infinity(); + + // serializer class + XNNSerializer _serializer; + // xnn subgraph + xnn_subgraph_t _subgraph_ptr; + // Set of all the tensor values throughout the jit graph + std::unordered_set _intermediate_tensors; + // Set of all the tensor values mapped to the xnnpack ids + std::unordered_map _val_to_ids; + // Vector containing the torch valued inputs/outputs, + // must be ordered to preserve the order of input/outputs + std::vector _inputs; + std::vector _outputs; + + // Graph passes for optimizing and tracing torchscript graph + // Essentially massaging the graph into a digestiable format for + // xnnpack graph lowering. + std::shared_ptr optimizeAndTraceGraph( + std::shared_ptr graph, + std::vector& example_inputs); + + // Gather all the intermediate tensor values within a graph. This + // skips through all prim constants. The purpose of this is for defining + // the tensor values beforehand for the xnnpack subgraph. + void gatherTensorValues(std::shared_ptr& graph); + + // Gathers the tensor values in a give node + void gatherNodeInputs(torch::jit::Node& node); + + // Helper function to determine if a jit value is a graph input + bool isGraphInput(torch::jit::Value* val); + + // Helper function to determine if a jit value is a graph output + bool isGraphOutput(torch::jit::Value* val); + + // Defines all xnnpack nodes for the nodes in the graph + void defineAllNodes(std::shared_ptr& graph); + + // Defines all xnn tensor values used throughout the graph + void defineAllTensorValues(); + + // Makes a pass through the graph and throws if any ops are unsupported + void checkOpsToDelegate(std::shared_ptr& graph); + + public: + XNNGraph() : _serializer(), _subgraph_ptr(nullptr) { + xnn_status status = xnn_initialize(/*allocator =*/nullptr); + TORCH_CHECK(xnn_status_success == status, "Failed to initialize xnnpack"); + } + + ~XNNGraph() { + xnn_deinitialize(); + if (_subgraph_ptr != nullptr) { + xnn_delete_subgraph(_subgraph_ptr); + } + } + + void buildXNNGraph( + std::shared_ptr& graph, + std::vector example_inputs); + + void runGraphOnInputs( + std::vector tensor_inputs, + std::vector tensor_outputs); + + std::string serializedXNNGraph(); + + std::vector> getGraphOutputShapes(); +}; + +} // namespace delegate +} // namespace xnnpack +} // namespace jit +} // namespace torch diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/cuda/interface.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/cuda/interface.h new file mode 100644 index 0000000000000000000000000000000000000000..926e4cb5d265c7a5d0851a8bceb6f12c76a754e8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/cuda/interface.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include +#include + +/* + * This file contains APIs for cuda fuser; + * + * We use an empty static struct to hold the function pointers, which are + * registered separately. This is to support cpu-only compilation. + * Registration is done in torch/csrc/jit/codegen/cuda/register_interface.cpp + */ + +namespace torch::jit::fuser::cuda { + +TORCH_API std::atomic& getCudaFusionGuardMode(); + +TORCH_API bool getSingletonFusion(); +TORCH_API bool setSingletonFusion(bool value); +TORCH_API bool getHorizontalFusion(); +TORCH_API bool setHorizontalFusion(bool value); + +// dummy struct to allow API registration +struct CudaFuserInterface { + void (*fn_compile_n)(Node*) = nullptr; + void (*fn_run_n_s)(const Node*, Stack&) = nullptr; + void (*fn_fuse_graph)(std::shared_ptr&) = nullptr; + bool (*fn_can_fuse_n)(const Node*) = nullptr; + void (*fn_insert_profile_inodes)(ProfilingRecord* pr) = nullptr; + bool (*fn_profile_n)(const Node*) = nullptr; + bool (*fn_skip_n)(const std::string&, bool flip) = nullptr; +}; + +// Get interface, this is used by registration and user facing API internally +TORCH_API CudaFuserInterface* getFuserInterface(); + +TORCH_API void compileFusionGroup(Node* fusion_node); +TORCH_API void runFusionGroup(const Node* fusion_node, Stack& stack); +TORCH_API void fuseGraph(std::shared_ptr&); +TORCH_API bool canFuseNode(const Node* node); +TORCH_API void InsertProfileNodesForCUDAFuser(ProfilingRecord* pr); +TORCH_API bool profileNode(const Node* node); + +TORCH_API bool skipNode(const std::string& symbol_str, bool flip = true); + +TORCH_API bool isEnabled(); +TORCH_API bool setEnabled(bool is_enabled); +TORCH_API bool canBeEnabled(); + +} // namespace torch::jit::fuser::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/arg_spec.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/arg_spec.h new file mode 100644 index 0000000000000000000000000000000000000000..923aa324aa7ae9860a23f7ec4a9f3d6258387b2d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/arg_spec.h @@ -0,0 +1,55 @@ +#pragma once +#include +#include // fmap +#include +#include +#include + +#include +#include + +namespace torch::jit::fuser { + +// Describes the (runtime) arguments to a kernel. +// ArgSpecs are also used as keys to lookup instantiated kernels, so +// they are hashable. +// Note: the device to run on is included in the arg spec because kernels +// are compiled per-device. +struct TORCH_API ArgSpec { + ArgSpec(at::TensorList inputs, const int _device) + : descs_{c10::fmap(inputs)}, + hash_code_{c10::get_hash(_device, inputs.size(), descs_)}, + device_{_device} {} + + // (Common) hash function + static size_t hash(const ArgSpec& spec) { + return spec.hash_code_; + } + + // Comparators + bool operator==(const ArgSpec& other) const { + return (descs_ == other.descs_ && device_ == other.device_); + } + + bool operator!=(const ArgSpec& spec) const { + return !(*this == spec); + } + + // Getters + size_t hashCode() const { + return hash_code_; + } + const std::vector& descs() const { + return descs_; + } + int device() const { + return device_; + } + + private: + std::vector descs_; + size_t hash_code_; + int device_; +}; + +} // namespace torch::jit::fuser diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/codegen.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/codegen.h new file mode 100644 index 0000000000000000000000000000000000000000..cab5ccf0eb5e6f3caa49f87e4f30181a311946f5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/codegen.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace torch::jit::fuser { + +// Creates a CPU or CUDA kernel for the given graph. +// Returns the C++ or CUDA string implementing the kernel. +TORCH_API std::string generateKernel( + const std::string& name, + const Graph& graph, + const std::vector>>& + inputs, + const std::vector>& outputs, + const bool use_cuda); + +} // namespace torch::jit::fuser diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/compiler.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/compiler.h new file mode 100644 index 0000000000000000000000000000000000000000..1707e22f5ceb0c698f17b4f7a48a096f42660885 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/compiler.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace torch::jit::fuser { + +// Performs device-independent "upfront" compilation of the given fusion_group, +// if it has not been registered already. +// Returns a key that can be used to run the fusion later +TORCH_API int64_t registerFusion(const Node* fusion_group); + +// Performs device-specific "runtime" compilation of the given kernel +// with the runtime arguments specified in ArgSpec. +// Outputs are allocated using map_size on the specified device. +TORCH_API std::shared_ptr compileKernel( + const KernelSpec& spec, + const ArgSpec& arg_spec, + const std::vector& map_size, + const at::Device& device); + +TORCH_API size_t nCompiledKernels(); + +TORCH_API int debugFuser(); + +using FusedKernelConstructor = std::function( + int16_t device, + std::string name, + std::string code, + std::vector input_desc, + std::vector output_desc, + std::vector chunk_desc, + std::vector concat_desc, + bool has_random)>; + +TORCH_API void registerFusionBackend( + at::Device::Type backend_type, + FusedKernelConstructor ctor); +TORCH_API bool hasFusionBackend(at::Device::Type backend_type); +struct TORCH_API RegisterFusionBackend{RegisterFusionBackend( + at::Device::Type backend_type, + FusedKernelConstructor ctor){ + registerFusionBackend(backend_type, std::move(ctor)); +} // namespace torch::jit::fuser +} +; + +} // namespace torch::jit::fuser diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..25acde8ae4c81b7702043936e4e433a6a7374a16 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::jit::fuser::cpu { + +// Represents a compiled CPU kernel and the metadata necessary to run it +struct TORCH_API FusedKernelCPU : public FusedKernel { + FusedKernelCPU( + std::string name, + std::string code, + std::vector input_desc, + std::vector output_desc, + std::vector chunk_desc, + std::vector concat_desc, + bool has_random); + + at::Backend backend() const override { + return at::Backend::CPU; + } + + void launch_raw(const uint32_t numel, std::vector& arguments) + const override { + kernel(numel, arguments.data()); + } + + private: + std::unique_ptr so_lib; + void (*kernel)(uint32_t, void**) = nullptr; +}; + +} // namespace torch::jit::fuser::cpu diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cpu/resource_strings.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cpu/resource_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..134451f335f83a5de81f962a0b82a58a3d7d51c1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cpu/resource_strings.h @@ -0,0 +1,101 @@ +#pragma once + +#include + +namespace torch::jit::fuser::cpu { + +/*with type_as not checking type of its input, a fusion group can have non-fp32 +tensor as input. Correct code for this case is generated, however, nvrtc does +not know how to handle int*_t integer types, so typedefs help it handle those +cases*/ + +static auto type_declarations_template = at::jit::CodeTemplate(R"( + +#define POS_INFINITY INFINITY +#define NEG_INFINITY -INFINITY + +typedef ${IndexType} IndexType; +template +struct TensorInfo { + T* data; + IndexType sizes[N]; + IndexType strides[N]; +}; +template +struct TensorInfo { + T * data; +}; +)"); + +static auto cpu_compilation_unit_template = at::jit::CodeTemplate(R"( +#include +#include +#include + +double rsqrt(double x) { + return 1.0/sqrt(x); +} + +float rsqrtf(float x) { + return 1.0f/sqrtf(x); +} + +double frac(double x) { + return x - trunc(x); +} + +float fracf(float x) { + return x - truncf(x); +} + +${type_declarations} + +#ifdef _MSC_VER +template struct int_of_size; + +#define DEFINE_INT_OF_SIZE(int_t) \ +template<> struct int_of_size { using type = int_t; } + +DEFINE_INT_OF_SIZE(int64_t); +DEFINE_INT_OF_SIZE(int32_t); +DEFINE_INT_OF_SIZE(int16_t); +DEFINE_INT_OF_SIZE(int8_t); + +#undef DEFINE_INT_OF_SIZE + +template +using int_same_size_t = typename int_of_size::type; + +#define IndexTypeLoop int_same_size_t +#define ToIndexTypeLoop(x) static_cast(x) +#else +#define IndexTypeLoop IndexType +#define ToIndexTypeLoop(x) x +#endif + +#define OMP_THRESHOLD 100000 +static void ${kernelName}_kernel(IndexType totalElements, ${formals}) { + #pragma omp parallel for if(totalElements > OMP_THRESHOLD) + for (IndexTypeLoop linearIndex = 0; + linearIndex < ToIndexTypeLoop(totalElements); + linearIndex += 1) { + // Convert `linearIndex` into an offset of tensor: + ${tensorOffsets} + // calculate the results + ${kernelBody} + } +} + +#ifdef _WIN32 +#define JIT_API __declspec(dllexport) +#else +#define JIT_API +#endif + +extern "C" +JIT_API void ${kernelName}(IndexType totalElements, void ** args) { + ${kernelName}_kernel(totalElements ${,argument_loads}); +} +)"); + +} // namespace torch::jit::fuser::cpu diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cpu/temp_file.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cpu/temp_file.h new file mode 100644 index 0000000000000000000000000000000000000000..dd21f7573f34e1cf7fcd81352b0bb76aff837d76 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cpu/temp_file.h @@ -0,0 +1,135 @@ +#pragma once + +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#else +#include +#endif + +#include +#include + +namespace torch::jit::fuser::cpu { + +#ifdef _MSC_VER +inline int wmkstemps(wchar_t* tmpl, int suffix_len) { + int len; + wchar_t* name; + int fd = -1; + int save_errno = errno; + + len = wcslen(tmpl); + if (len < 6 + suffix_len || + wcsncmp(&tmpl[len - 6 - suffix_len], L"XXXXXX", 6)) { + return -1; + } + + name = &tmpl[len - 6 - suffix_len]; + + std::random_device rd; + do { + for (unsigned i = 0; i < 6; ++i) { + name[i] = "abcdefghijklmnopqrstuvwxyz0123456789"[rd() % 36]; + } + + fd = _wopen(tmpl, _O_RDWR | _O_CREAT | _O_EXCL, _S_IWRITE | _S_IREAD); + } while (errno == EEXIST); + + if (fd >= 0) { + errno = save_errno; + return fd; + } else { + return -1; + } +} +#endif + +struct TempFile { + AT_DISALLOW_COPY_AND_ASSIGN(TempFile); + + TempFile(const std::string& t, int suffix) { +#ifdef _MSC_VER + auto wt = c10::u8u16(t); + std::vector tt(wt.c_str(), wt.c_str() + wt.size() + 1); + int fd = wmkstemps(tt.data(), suffix); + AT_ASSERT(fd != -1); + file_ = _wfdopen(fd, L"r+"); + auto wname = std::wstring(tt.begin(), tt.end() - 1); + name_ = c10::u16u8(wname); +#else + // mkstemps edits its first argument in places + // so we make a copy of the string here, including null terminator + std::vector tt(t.c_str(), t.c_str() + t.size() + 1); + int fd = mkstemps(tt.data(), suffix); + AT_ASSERT(fd != -1); + file_ = fdopen(fd, "r+"); + // - 1 because tt.size() includes the null terminator, + // but std::string does not expect one + name_ = std::string(tt.begin(), tt.end() - 1); +#endif + } + + const std::string& name() const { + return name_; + } + + void sync() { + fflush(file_); + } + + void write(const std::string& str) { + size_t result = fwrite(str.c_str(), 1, str.size(), file_); + AT_ASSERT(str.size() == result); + } + +#ifdef _MSC_VER + void close() { + if (file_ != nullptr) { + fclose(file_); + } + file_ = nullptr; + } +#endif + + FILE* file() { + return file_; + } + + ~TempFile() { +#ifdef _MSC_VER + if (file_ != nullptr) { + fclose(file_); + } + auto wname = c10::u8u16(name_); + if (!wname.empty() && _waccess(wname.c_str(), 0) != -1) { + _wunlink(wname.c_str()); + } +#else + if (file_ != nullptr) { + // unlink first to ensure another mkstemps doesn't + // race between close and unlink + unlink(name_.c_str()); + fclose(file_); + } +#endif + } + + private: + FILE* file_ = nullptr; + std::string name_; +}; + +} // namespace torch::jit::fuser::cpu diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d635049e758a2cec274c834e0a7eede9d4d8721f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace torch::jit::fuser::cuda { + +// query codegen output arch and target +TORCH_CUDA_CU_API void codegenOutputQuery( + const cudaDeviceProp* const prop, + int& major, + int& minor, + bool& compile_to_sass); + +// A class holding metadata for an actual CUDA function. +// Note: CUDA functions are per device. +struct TORCH_CUDA_CU_API FusedKernelCUDA + : public ::torch::jit::fuser::FusedKernel { + FusedKernelCUDA( + at::DeviceIndex device, + std::string name, + std::string code, + std::vector input_desc, + std::vector output_desc, + std::vector chunk_desc, + std::vector concat_desc, + bool has_random); + + ~FusedKernelCUDA() override; + + void launch_raw(const uint32_t numel, std::vector& arguments) + const override; + + at::Backend backend() const override { + return at::Backend::CUDA; + } + + private: + static constexpr auto kBlockSize = 128; + + // Note: per device to store device properties and compute launch heuristics + // Acquiring these values at launch time would be too slow + at::DeviceIndex device_; + int maxBlocks_{}; + cudaDeviceProp* prop_{}; + std::vector ptx_; + CUmodule module_{}; + CUfunction function_{}; +}; + +} // namespace torch::jit::fuser::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..ff2ef1f2377ce1c55ba7f25e31f4c3c7e2d5a252 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h @@ -0,0 +1,405 @@ +#pragma once + +#include +#include + +namespace torch::jit::fuser::cuda { + +/*with type_as not checking type of its input, a fusion group can have non-fp32 +tensor as input. Correct code for this case is generated, however, nvrtc does +not know how to handle int*_t integer types, so typedefs help it handle those +cases*/ + +static constexpr auto bfloat16_type_string = "__nv_bfloat16"; + +#if defined(USE_ROCM) +static auto type_declarations_template = at::jit::CodeTemplate(R"( +${HalfHeader} +${BFloat16Header} +${RandHeader} + +#define NAN __int_as_float(0x7fffffff) +#define POS_INFINITY __int_as_float(0x7f800000) +#define NEG_INFINITY __int_as_float(0xff800000) + +typedef ${IndexType} IndexType; +template +struct TensorInfo { + T* data; + IndexType sizes[N]; + IndexType strides[N]; +}; +template +struct TensorInfo { + T * data; +}; +)"); +#else +static auto type_declarations_template = at::jit::CodeTemplate(R"( +typedef unsigned char uint8_t; +typedef signed char int8_t; +typedef short int int16_t; +typedef long long int int64_t; +typedef unsigned long long int uint64_t; +${HalfHeader} +${BFloat16Header} +${RandHeader} + +#define NAN __int_as_float(0x7fffffff) +#define POS_INFINITY __int_as_float(0x7f800000) +#define NEG_INFINITY __int_as_float(0xff800000) + +typedef ${IndexType} IndexType; +template +struct TensorInfo { + T* data; + IndexType sizes[N]; + IndexType strides[N]; +}; +template +struct TensorInfo { + T * data; +}; +)"); +#endif + +// We rewrite the code for philox RNG from curand as nvrtc couldn't resolve the +// curand header correctly. +constexpr auto rand_support_literal = R"( + + class Philox { + public: + __device__ inline Philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) { + key.x = (unsigned int)seed; + key.y = (unsigned int)(seed >> 32); + counter = make_uint4(0, 0, 0, 0); + counter.z = (unsigned int)(subsequence); + counter.w = (unsigned int)(subsequence >> 32); + STATE = 0; + incr_n(offset / 4); + } + + __device__ inline unsigned long operator()() { + if(STATE == 0) { + uint4 counter_ = counter; + uint2 key_ = key; + for(int i = 0; i < 9; i++) { + counter_ = single_round(counter_, key_); + key_.x += (kPhilox10A); key_.y += (kPhilox10B); + } + output = single_round(counter_, key_); + incr(); + } + unsigned long ret; + switch(STATE) { + case 0: ret = output.x; break; + case 1: ret = output.y; break; + case 2: ret = output.z; break; + case 3: ret = output.w; break; + } + STATE = (STATE + 1) % 4; + return ret; + } + + private: + uint4 counter; + uint4 output; + uint2 key; + unsigned int STATE; + __device__ inline void incr_n(unsigned long long n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + counter.x += nlo; + if (counter.x < nlo) + nhi++; + counter.y += nhi; + if (nhi <= counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } + __device__ inline void incr() { + if (++counter.x) + return; + if (++counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } + __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, + unsigned int *result_high) { + *result_high = __umulhi(a, b); + return a*b; + } + + __device__ inline uint4 single_round(uint4 ctr, uint2 key) { + unsigned int hi0; + unsigned int hi1; + unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); + unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); + + uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; + return ret; + } + + static const unsigned long kPhilox10A = 0x9E3779B9; + static const unsigned long kPhilox10B = 0xBB67AE85; + static const unsigned long kPhiloxSA = 0xD2511F53; + static const unsigned long kPhiloxSB = 0xCD9E8D57; + }; + + // Inverse of 2^32. + #define M_RAN_INVM32 2.3283064e-10f + __device__ __inline__ float uniform(unsigned int x) { + return x * M_RAN_INVM32; + } +)"; + +constexpr auto rand_param = + ",unsigned long long seed, unsigned long long offset"; + +constexpr auto rand_init = R"( + int idx = blockIdx.x*blockDim.x + threadIdx.x; + Philox rnd(seed, idx, offset); +)"; + +static auto cuda_compilation_unit_template = at::jit::CodeTemplate(R"( +${type_declarations} + +extern "C" __global__ +void ${kernelName}(IndexType totalElements, ${formals} ${RandParam}) { + ${RandInit} + // check whether do vectorized load/store and allocate buffer + bool flag_vec4 = true; + ${tensorChecks} + if (flag_vec4) { + for (IndexType linearIndex = 4 * (blockIdx.x * blockDim.x + threadIdx.x); + linearIndex < totalElements; + linearIndex += 4 * gridDim.x * blockDim.x) { + // Convert `linearIndex` into an offset of tensor as it is: + ${tensorOffsets} + // load 4 at a time + ${kernelLoad} + #pragma unroll 4 + for (int i=0; i<4; i++) { + // calculate the results + ${kernelBody_vec4} + } + // store 4 at a time + ${kernelStore} + } + } else { + for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x; + linearIndex < totalElements; + linearIndex += gridDim.x * blockDim.x) { + // Convert `linearIndex` into an offset of tensor: + ${tensorOffsets} + // calculate the results + ${kernelBody} + } + } +} +)"); + +// This snippet enables half support in the jit. Following the pattern for +// reductions, fp16 input data is immediately upconverted to float +// with __half2float(). All mathematical operations are done on float +// values, and if needed the intermediate float representation is +// converted to half with __float2half() when writing to a half tensor. +#if defined(USE_ROCM) +constexpr auto half_support_literal = + R"( +typedef __half half; +)"; +#else +constexpr auto half_support_literal = + R"( +#define __HALF_TO_US(var) *(reinterpret_cast(&(var))) +#define __HALF_TO_CUS(var) *(reinterpret_cast(&(var))) +#if defined(__cplusplus) + struct __align__(2) __half { + __host__ __device__ __half() { } + + protected: + unsigned short __x; + }; + + /* All intrinsic functions are only available to nvcc compilers */ + #if defined(__CUDACC__) + /* Definitions of intrinsics */ + __device__ __half __float2half(const float f) { + __half val; + asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__HALF_TO_US(val)) : "f"(f)); + return val; + } + + __device__ float __half2float(const __half h) { + float val; + asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(h))); + return val; + } +)" + // MSVC's preprocessor (but not the standard compiler) has a bug + // where it incorrectly tokenizes raw string literals, ending when it sees a + // " this causes the #endif in this string literal to be treated as a + // preprocessor token which, in turn, cause sccache on windows CI to fail. + // See https://godbolt.org/z/eVTIJq as an example. + // This workaround uses string-pasting to separate the " and the #endif into + // different strings + R"( + #endif /* defined(__CUDACC__) */ +#endif /* defined(__cplusplus) */ +#undef __HALF_TO_US +#undef __HALF_TO_CUS + +typedef __half half; +)"; +#endif + +#if defined(USE_ROCM) +constexpr auto bfloat16_support_literal = + R"( +#ifndef __align__ +#define __align__(x) __attribute__((aligned(x))) +#endif + +typedef struct __align__(2) { + unsigned short x; +} +__nv_bfloat16_raw; + +#if defined(__cplusplus) +struct __align__(2) __nv_bfloat16 { + __host__ __device__ __nv_bfloat16() {} + + __host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) { + __x = hr.x; + return *this; + } + + unsigned short __x; +}; + +__device__ unsigned short __internal_float2bfloat16( + const float f, + unsigned int& sign, + unsigned int& remainder) { + unsigned int x; + + x = __float_as_uint(f); + + if ((x & 0x7fffffffU) > 0x7f800000U) { + sign = 0U; + remainder = 0U; + return static_cast(0x7fffU); + } + sign = x >> 31; + remainder = x << 16; + return static_cast(x >> 16); +} + +/* Definitions of intrinsics */ +__device__ __nv_bfloat16 __float2bfloat16(const float a) { + __nv_bfloat16 val; + __nv_bfloat16_raw r; + unsigned int sign; + unsigned int remainder; + r.x = __internal_float2bfloat16(a, sign, remainder); + if ((remainder > 0x80000000U) || + ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) { + r.x++; + } + val = r; + return val; +} + +__device__ float __bfloat162float(const __nv_bfloat16 a) { + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(a.__x) << 16}; + return u.fp32; +} +#endif /* defined(__cplusplus) */ +)"; +#else +constexpr auto bfloat16_support_literal = + R"( +#define __BFLOAT16_TO_US(var) *(reinterpret_cast(&(var))) +#define __BFLOAT16_TO_CUS(var) \ + *(reinterpret_cast(&(var))) + +typedef struct __align__(2) { + unsigned short x; +} +__nv_bfloat16_raw; + +#if defined(__cplusplus) +struct __align__(2) __nv_bfloat16 { + __host__ __device__ __nv_bfloat16() {} + + __host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) { + __x = hr.x; + return *this; + } + + protected: + unsigned short __x; +}; + +#if defined(__CUDACC__) +__device__ unsigned short __internal_float2bfloat16( + const float f, + unsigned int& sign, + unsigned int& remainder) { + unsigned int x; + + x = __float_as_uint(f); + + if ((x & 0x7fffffffU) > 0x7f800000U) { + sign = 0U; + remainder = 0U; + return static_cast(0x7fffU); + } + sign = x >> 31; + remainder = x << 16; + return static_cast(x >> 16); +} + +/* Definitions of intrinsics */ +__device__ __nv_bfloat16 __float2bfloat16(const float a) { + __nv_bfloat16 val; +#if __CUDA_ARCH__ >= 800 + asm("{ cvt.rn.bf16.f32 %0, %1;}\n" : "=h"(__BFLOAT16_TO_US(val)) : "f"(a)); +#else + __nv_bfloat16_raw r; + unsigned int sign; + unsigned int remainder; + r.x = __internal_float2bfloat16(a, sign, remainder); + if ((remainder > 0x80000000U) || + ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) { + r.x++; + } + val = r; +#endif + return val; +} + +__device__ float __bfloat162float(const __nv_bfloat16 a) { + float val; + asm("{ mov.b32 %0, {0,%1};}\n" : "=f"(val) : "h"(__BFLOAT16_TO_CUS(a))); + return val; +} +#endif /* defined(__CUDACC__) */ +#endif /* defined(__cplusplus) */ +#undef __BFLOAT16_TO_US +#undef __BFLOAT16_TO_CUS +)"; +#endif + +} // namespace torch::jit::fuser::cuda diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/executor.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/executor.h new file mode 100644 index 0000000000000000000000000000000000000000..188a25ed8cc65cf1c03ffc1f73b2ea265fa7af9e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/executor.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace torch::jit::fuser { + +// Runs the fusion associated with the key (see registerFusion() in interface.h) +// on the inputs taken from the given Stack. +TORCH_API bool runFusion( + const int64_t key, + Stack& stack, + std::string* code_out = nullptr); + +} // namespace torch::jit::fuser diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/fallback.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/fallback.h new file mode 100644 index 0000000000000000000000000000000000000000..af0ff32641ce65fe59e31574eb56075bf7c806fc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/fallback.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +#include + +namespace torch::jit::fuser { + +void runFallback(int64_t key, Stack& stack); + +} // namespace torch::jit::fuser diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/fused_kernel.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/fused_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..de00904a749c8e062ce91b6d0b4cc3b8193ed9e4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/fused_kernel.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::jit::fuser { + +struct FusedKernel { + AT_DISALLOW_COPY_AND_ASSIGN(FusedKernel); + + FusedKernel( + std::string name, + std::string code, + std::vector input_desc, + std::vector output_desc, + std::vector chunk_desc, + std::vector concat_desc, + bool has_random) + : name_(std::move(name)), + code_(std::move(code)), + input_desc_(std::move(input_desc)), + output_desc_(std::move(output_desc)), + chunk_desc_(std::move(chunk_desc)), + concat_desc_(std::move(concat_desc)), + has_random_(has_random) {} + + virtual ~FusedKernel() = default; + + // arguments is a list of pointers to the arguments for the compiled CUDA/CPU + // code. + // The format of arguments is suitable for directly passing to a call to + // cuLaunchKernel as the kernel arguments. + // Currently the first argument is a pointer to numel (for passing to + // CUDA code), and the remainder are pointers to the TensorInfo structs + // that compiled code uses to load Tensor data. + // launch_with_tensors handles packing at::Tensors into this arguments array. + // CPU code uses the same convension so that launch_with_tensors can be + // shared. + virtual void launch_raw(const uint32_t numel, std::vector& arguments) + const = 0; + virtual at::Backend backend() const = 0; + + // Getters + const std::string& name() const { + return name_; + } + const std::string& code() const { + return code_; + } + const std::vector& inputDesc() const { + return input_desc_; + } + const std::vector& outputDesc() const { + return output_desc_; + } + const std::vector& chunkDesc() const { + return chunk_desc_; + } + const std::vector& concatDesc() const { + return concat_desc_; + } + bool hasRandom() const { + return has_random_; + } + + protected: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + const std::string name_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + const std::string code_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + const std::vector input_desc_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + const std::vector output_desc_; + + // same size as input_desc, describes whether an + // input should be broken into subtensors (chunks) + // to be consumed by the fusion group + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + const std::vector chunk_desc_; + + // same size as output_desc, describes whether + // an output is actually a concatenation of + // many subtensors that the fusion group produces + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + const std::vector concat_desc_; + + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + const bool has_random_; +}; + +} // namespace torch::jit::fuser diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/interface.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/interface.h new file mode 100644 index 0000000000000000000000000000000000000000..977e90191160ce7991d9b06c2a0d0c0952742b3b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/interface.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::jit { + +constexpr int kCPUDevice = -1; + +// Assigns a "key" to the given fusion_group that it can use to run its +// fusion later (via runFusion() below). +TORCH_API int64_t registerFusion(const Node* fusion_group); + +// Runs the fusion corresponding to the given key on the inputs +// found on the stack. Outputs are placed on the same stack. +// In some cases a fusion cannot be run and a fallback path where +// PyTorch's interpreter runs the graph instead is attempted. +TORCH_API void runFusion(const int64_t key, Stack& stack); + +// True if the respective devices can fuse, false otherwise +TORCH_API bool canFuseOnCPU(); +TORCH_API bool canFuseOnGPU(); + +// Sets whether fusion on the CPU is allowed (disabled by default due to +// flakiness) +TORCH_API void overrideCanFuseOnCPU(bool value); + +// Sets whether fusion on CPU must use LLVM Codegen and not SimplieIREval +TORCH_API void overrideMustUseLLVMOnCPU(bool value); + +// Sets whether fusion on the GPU is allowed (enabled by default) +TORCH_API void overrideCanFuseOnGPU(bool value); + +// Treats the given graph as a fusion group and launches it on the +// specified device with the given inputs. +// Returns the outputs. +TORCH_API std::vector debugLaunchGraph( + Graph& graph, + at::ArrayRef inputs); + +// Treats the given graph as a fusion group and returns the generated code. +TORCH_API std::string debugGetFusedKernelCode( + Graph& graph, + at::ArrayRef inputs); + +TORCH_API size_t nCompiledKernels(); + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/kernel_cache.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/kernel_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..370f782453b1cbf116ff5cdee5955124a04e3961 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/kernel_cache.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace torch::jit::fuser { + +// A thread-safe cache interface. + +// Normalizes the graph by canonicalizing and erasing shape information +TORCH_API std::shared_ptr normalizeGraphForCache( + const std::shared_ptr& graph); + +// Stores the given graph, returning the key used to access it +TORCH_API int64_t store(std::shared_ptr graph); + +// Given a graph, find a KernelSpec based on it +TORCH_API std::optional lookupGraph( + const std::shared_ptr& graph); + +// Returns the graph corresponding to the given key (if it exists) +TORCH_API std::optional retrieve(const int64_t key); + +// Returns the size of the fusion key -> KernelSpec cache. +// Only used for testing. +TORCH_API int64_t debugNumCachedKernelSpecs(); + +} // namespace torch::jit::fuser diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/kernel_spec.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/kernel_spec.h new file mode 100644 index 0000000000000000000000000000000000000000..4f7159af2a45f0607342db5a78c0528c77c3b5ab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/kernel_spec.h @@ -0,0 +1,144 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace torch::jit::fuser { + +// Helper struct containing partition information: the number of tensors +// created and the dimension the partitioning is performed on. +// Note: created during upfront compilation, once the tensors are known +// at runtime the partition info is logically combined with the tensor +// descriptions to create PartitionDesc objects. +struct TORCH_API PartitionInfo { + PartitionInfo(const int64_t _nSubTensors, const int64_t _dim) + : nSubTensors_{_nSubTensors}, dim_{_dim} {} + + int64_t nSubTensors() const { + return nSubTensors_; + } + int64_t dim() const { + return dim_; + } + + private: + int64_t nSubTensors_; + int64_t dim_; +}; + +// "Kernel Specification." - Contains device-independent fusion information. +// Each kernel specification contains a map of instantiated generated functions +// that implement some or most of its functionality. Multiple generated +// functions are needed by each abstract specification because of different +// devices (cpu vs gpu, different gpus) and different inputs (int vs float, +// contiguous vs discontiguous). +// Note: uses a mutex to control access to its kernel store +// Note: unordered containers do not invalidate references/pointers on +// rehashing, which is critical for thread-safety. +// TODO: allow abstract kernels to use multiple generated kernels +// TODO: allow abstract kernels to reuse generated kernels from common pool +struct TORCH_API KernelSpec { + // Note: assumes the spec is a single block + // Note: This is the appropriate place to generalize if you want to add other + // passes to upfront compilation that walk the graph. + KernelSpec(const int64_t _key, const std::shared_ptr& _graph) + : key_{_key}, + graph_{_graph}, + code_{_graph, ""}, + nInputs_{_graph->inputs().size()} + + { + // No need to iterate over reference since n is pointer + for (const auto n : graph_->nodes()) { + static_assert(std::is_pointer_v, "n must be a pointer"); + if (n->kind() == aten::rand_like) { + has_random_ = true; + break; + } + } + nTensorInputs_ = std::count_if( + graph_->inputs().begin(), graph_->inputs().end(), [](const Value* v) { + return v->type()->isSubtypeOf(*TensorType::get()); + }); + } + + // Getters + int64_t key() const { + return key_; + } + std::shared_ptr graph() const { + return graph_; + } + const Code& code() const { + return code_; + } + int64_t nInputs() const { + return nInputs_; + } + int64_t nTensorInputs() const { + return nTensorInputs_; + } + + std::vector>& inputBroadcastGroups() { + return inputBroadcastGroups_; + } + const std::vector>& inputBroadcastGroups() const { + return inputBroadcastGroups_; + } + + std::vector& inputChunks() { + return inputChunks_; + } + const std::vector& inputChunks() const { + return inputChunks_; + } + + bool hasRandom() const { + return has_random_; + } + + // Cache functions + std::optional> findKernel( + const ArgSpec& arg_spec) const { + std::lock_guard guard{mutex_}; + const auto it = kernels_.find(arg_spec); + if (it == kernels_.end()) + return std::nullopt; + return it->second; + } + void cacheKernel( + const ArgSpec& arg_spec, + const std::shared_ptr& kernel) const { + std::lock_guard guard{mutex_}; + kernels_.emplace(arg_spec, kernel); + } + + private: + int64_t key_; + std::shared_ptr graph_; + Code code_; + uint64_t nInputs_; + uint64_t nTensorInputs_{}; + std::vector> inputBroadcastGroups_; + std::vector inputChunks_; + bool has_random_{false}; + mutable std::mutex mutex_; + mutable std:: + unordered_map, c10::hash> + kernels_; +}; + +} // namespace torch::jit::fuser diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/partition_desc.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/partition_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..964e1821364a505c4fd2d93a1afeaf88faa5e571 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/partition_desc.h @@ -0,0 +1,58 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace torch::jit::fuser { + +// Descriptor for chunk-ing an input tensor into subtensors +// OR concat-ing an output tensor from subtensors +// Note: default constructed used for tensors that do not participate in +// chunk or cat operations. +struct TORCH_API PartitionDesc { + PartitionDesc() : nSubTensors_{1}, dim_{0} {} + + PartitionDesc(const TensorDesc& _desc, size_t _nSubTensors, size_t _dim) + : nSubTensors_{_nSubTensors}, dim_{_dim} { + AT_ASSERT(nSubTensors_ > 1); + std::vector cont = _desc.contiguity; + if (dim_ > 0) { + // when we narrow the concatenated output/chunked input + // we make the size[dim] smaller while keeping the stride[dim] the same, + // meaning: stride[dim - 1] != stride[dim]*size[dim] + // so dim - 1 is no longer contiguous + cont[dim_ - 1] = false; + } + subTensorDesc_ = std::make_shared(_desc.scalar_type, cont); + } + + bool isNoop() const { + return (nSubTensors_ == 1); + } + size_t nSubTensors() const { + return nSubTensors_; + } + size_t dim() const { + return dim_; + } + std::shared_ptr subTensorDesc() { + return subTensorDesc_; + } + const std::shared_ptr subTensorDesc() const { + return subTensorDesc_; + } + + private: + size_t nSubTensors_; // == 1 for tensors that should not be operated on via + // chunk/cat + size_t dim_; // dimension along which the chunk/concat occurs + std::shared_ptr + subTensorDesc_; // descriptor for the subtensor, if it exists +}; + +} // namespace torch::jit::fuser diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/tensor_desc.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/tensor_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..0c5db65d54ad1a2b04837d1e84e4c2e5a056e9b3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/tensor_desc.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::jit::fuser { + +// type information needed by the compiler for input/outputs +// contiguity[i] is true if the dim i is contiguous with dim i + 1. +// contiguity.back() == true means strides.back() == 1. +struct TORCH_API TensorDesc { + at::ScalarType scalar_type; + std::vector contiguity; + + TensorDesc(const at::ScalarType& type, const std::vector& contiguity) + : scalar_type{type}, contiguity{contiguity} { + if (contiguity.empty()) { + nDim_ = 0; + } else { + nDim_ = std::count(contiguity.begin(), contiguity.end(), false) + + (lastIsContiguous() ? 1 : 0); + } + } + + // Delegating constructors + TensorDesc( + const at::ScalarType& type, + const at::IntArrayRef& sizes, + const at::IntArrayRef& strides) + : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {} + + TensorDesc(const at::Tensor& t) + : TensorDesc(t.scalar_type(), t.sizes(), t.strides()) {} + + TensorDesc(const c10::TensorTypePtr& type) + : TensorDesc( + type->scalarType().value(), + type->sizes().concrete_sizes().value(), + type->strides().concrete_sizes().value()) {} + + // number of dimensions after contiguity compression + size_t nDim() const { + return nDim_; + } + + // True iff innermost stride is 1 + bool lastIsContiguous() const { + return (contiguity.empty() || contiguity.back()); + } + + static std::vector findContiguous( + const at::IntArrayRef& sizes, + const at::IntArrayRef& strides) { + AT_ASSERT(sizes.size() == strides.size()); + std::vector cont(sizes.size()); + for (size_t i = 0; i < sizes.size(); ++i) { + const auto expected_stride = + (i + 1 < sizes.size()) ? sizes[i + 1] * strides[i + 1] : 1; + cont[i] = (strides[i] == expected_stride); + } + return cont; + } + + bool operator==(const TensorDesc& desc) const { + return scalar_type == desc.scalar_type && contiguity == desc.contiguity; + } + + bool operator!=(const TensorDesc& desc) const { + return !(*this == desc); + } + + static size_t hash(const TensorDesc& spec) { + return c10::get_hash( + spec.scalar_type, + spec.nDim_, + std::hash>{}(spec.contiguity)); + } + + private: + size_t nDim_; +}; + +inline std::ostream& operator<<(std::ostream& out, const TensorDesc& d) { + out << d.scalar_type << "["; + for (const auto b : d.contiguity) + out << b << ";"; + out << "]"; + return out; +} + +} // namespace torch::jit::fuser diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/tensor_info.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/tensor_info.h new file mode 100644 index 0000000000000000000000000000000000000000..77a0d8bacdf23b21959bfad2cc29ef8cc715c7a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/fuser/tensor_info.h @@ -0,0 +1,24 @@ +#pragma once +#include + +#include +#include + +namespace torch::jit::fuser { + +// Host-side view of TensorInfo +// Note dims[0] - we need to dynamically allocate the dims. +struct TORCH_API TensorInfo { + uint32_t* sizes(size_t nDim) { + return &sizes_strides[0]; + } + uint32_t* strides(size_t nDim) { + return &sizes_strides[nDim]; + } + + void* data; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + uint32_t sizes_strides[0]; +}; + +} // namespace torch::jit::fuser diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..24be190ec5383bb42d75cf0b2a9e94f4c8b4713b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h @@ -0,0 +1,272 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch::jit::fuser::onednn { + +// Engine represents a device and its context. From the device kind, the engine +// knows how to generate code for the target device and what kind of device +// object to be expected. The device id ensures that there is a unique engine +// being created for each device. The device handle passed from PyTorch allows +// oneDNN Graph implementation to work on the device specified by PyTorch, which +// is currently CPU, so we only have one engine. +// Ref: +// https://oneapi-spec.uxlfoundation.org/specifications/oneapi/latest/elements/onednn/source/graph/programming_model#engine +struct Engine { + // CPU engine singleton + static dnnl::engine& getEngine(); + Engine(const Engine&) = delete; + void operator=(const Engine&) = delete; +}; + +// Stream is the logical abstraction for execution units. It is created on top +// of oneDNN Graph engine. A compiled oneDNN Graph partition is submitted to a +// stream for execution. +struct Stream { + // CPU stream singleton + static dnnl::stream& getStream(); + Stream(const Stream&) = delete; + void operator=(const Stream&) = delete; +}; + +struct LlgaTensorDesc { + using desc = dnnl::graph::logical_tensor; + + LlgaTensorDesc( + size_t tid, + std::vector sizes, + std::vector strides, + desc::data_type dtype, + desc::property_type property_type) + : tid_(tid), + sizes_(std::move(sizes)), + strides_(std::move(strides)), + dtype_(dtype), + property_type_(property_type), + layout_type_(desc::layout_type::strided), + layout_id_(-1) {} + + LlgaTensorDesc(const desc& t) + : tid_(t.get_id()), + sizes_(t.get_dims()), + strides_({-1}), + dtype_(t.get_data_type()), + property_type_(t.get_property_type()), + layout_type_(t.get_layout_type()), + layout_id_(-1) { + if (is_opaque()) { + layout_id_ = t.get_layout_id(); + } + if (is_strided()) { + strides_ = t.get_strides(); + } + } + + LlgaTensorDesc(const torch::jit::Value* v) + : LlgaTensorDesc( + v->unique(), + {}, + {}, + desc::data_type::f32, + get_property_type(v)) { + if (v->type()->isSubtypeOf(TensorType::get())) { + auto tt = v->type()->cast(); + + if (tt->scalarType()) { + dtype_ = getLlgaDataType(tt->scalarType().value()); + } + + auto sizes = tt->sizes(); + if (sizes.sizes()) { + for (auto d : *sizes.sizes()) { + sizes_.push_back(d.value_or(DNNL_GRAPH_UNKNOWN_DIM)); + } + } + + auto strides = tt->strides(); + if (strides.sizes()) { + for (auto d : *strides.sizes()) { + strides_.push_back(d.value_or(DNNL_GRAPH_UNKNOWN_DIM)); + } + } + } + } + + LlgaTensorDesc supplementTensorInfo(const at::Tensor& t) const; + + desc::data_type getLlgaDataType(at::ScalarType dt) const; + + at::ScalarType aten_scalar_type() const; + + const std::vector& sizes() const { + return sizes_; + } + + const std::vector& strides() const { + TORCH_CHECK(!is_opaque(), "Cannot get strides on opaque layout"); + return strides_; + } + + size_t tid() const { + return tid_; + } + + LlgaTensorDesc tid(uint64_t new_id) const { + auto ret = *this; + ret.tid_ = new_id; + return ret; + } + + desc::data_type dtype() const { + return dtype_; + } + + LlgaTensorDesc dtype(desc::data_type new_dtype) const { + return LlgaTensorDesc(tid_, sizes_, strides_, new_dtype, property_type_); + } + + desc::layout_type layout_type() const { + return layout_type_; + } + + LlgaTensorDesc layout_type(desc::layout_type new_layout_type) { + auto ret = *this; + ret.layout_type_ = new_layout_type; + return ret; + } + + desc::property_type get_property_type(const torch::jit::Value* v) { + switch (v->node()->kind()) { + case prim::Constant: + return desc::property_type::constant; + default: + return desc::property_type::variable; + } + } + + LlgaTensorDesc any() { + return layout_type(desc::layout_type::any); + } + + size_t storage_size() const { + return logical_tensor().get_mem_size(); + } + + desc logical_tensor() const { + if (is_dimensionality_unknown()) { + return desc( + tid_, dtype_, DNNL_GRAPH_UNKNOWN_NDIMS, layout_type_, property_type_); + } else if (is_opaque()) { + return desc(tid_, dtype_, sizes_, layout_id_, property_type_); + } else if (is_any()) { + return desc(tid_, dtype_, sizes_, layout_type_, property_type_); + } else { + return desc(tid_, dtype_, sizes_, strides_, property_type_); + } + } + + bool is_strided() const { + return layout_type_ == desc::layout_type::strided; + } + + bool is_any() const { + return layout_type_ == desc::layout_type::any; + } + + bool is_opaque() const { + return layout_type_ == desc::layout_type::opaque; + } + + bool operator==(const LlgaTensorDesc& desc) const { + return tid_ == desc.tid_ && sizes_ == desc.sizes_ && + dtype_ == desc.dtype_ && layout_type_ == desc.layout_type_ && + ((is_opaque() && layout_id_ == desc.layout_id_) || + strides_ == desc.strides_); + } + + bool operator!=(const LlgaTensorDesc& desc) const { + return (tid_ != desc.tid_) || (sizes_ != desc.sizes_) || + (dtype_ != desc.dtype_) || (layout_type_ != desc.layout_type_) || + !((is_opaque() && (layout_id_ == desc.layout_id_)) || + (strides_ == desc.strides_)); + } + + static size_t hash(const LlgaTensorDesc& desc) { + return c10::get_hash( + desc.tid_, + desc.sizes_, + desc.dtype_, + desc.layout_type_, + desc.layout_id_); + } + + void set_compute_inplace() { + compute_inplace_ = true; + } + + void set_input_tensor_index(size_t index) { + input_tensor_index_ = index; + } + + bool reuses_input_tensor() { + return compute_inplace_; + } + + size_t get_input_tensor_index() { + return input_tensor_index_; + } + + private: + bool is_dimensionality_unknown() const { + return sizes_.empty(); + } + + size_t tid_; + std::vector sizes_; + std::vector strides_; + desc::data_type dtype_; + desc::property_type property_type_; + desc::layout_type layout_type_; + size_t layout_id_; + // If this is an output tensor, and querying the compiled partition would + // determine that this tensor would reuse its input tensor, then + // compute_inplace would be true, and input_tensor_index would be the index of + // the corresponding input tensor in inputSpecs_ of the LlgaKernel object. + bool compute_inplace_ = false; + size_t input_tensor_index_{}; +}; + +// Initially, oneDNN Graph also used to have blocked layout for tensors between +// partitions, and the LlgaTensorImpl wrapper helped us bypass guard checks. +// oneDNN Graph has switched over to using strided tensors between partitions, +// but this wrapper still helps us bypass guard checks because the strides of +// tensors between partitions would be different from the ones the guard is +// otherwise expecting. +struct TORCH_API LlgaTensorImpl : public c10::TensorImpl { + LlgaTensorImpl( + at::Storage&& storage, + const caffe2::TypeMeta& data_type, + const LlgaTensorDesc& desc); + + const LlgaTensorDesc& desc() const { + return desc_; + } + + static at::Tensor llga_to_aten_tensor(LlgaTensorImpl* llgaImpl); + + private: + LlgaTensorDesc desc_; +}; + +at::Tensor empty_llga( + const LlgaTensorDesc& desc, + const c10::TensorOptions& options); + +dnnl::graph::tensor llga_from_aten_tensor(const at::Tensor& tensor); + +} // namespace torch::jit::fuser::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/decompose_silu.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/decompose_silu.h new file mode 100644 index 0000000000000000000000000000000000000000..fc4f115f1bd23b49341a91c2320e0bcd7e31540a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/decompose_silu.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace torch::jit::fuser::onednn { + +void DecomposeSiluForLLGA(std::shared_ptr& graph); + +} // namespace torch::jit::fuser::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/defer_size_check.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/defer_size_check.h new file mode 100644 index 0000000000000000000000000000000000000000..e6d654199b2ffdde70f68b542f36f6efbc820cfb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/defer_size_check.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace torch::jit::fuser::onednn { + +void DeferSizeCheck(std::shared_ptr& graph); + +} // namespace torch::jit::fuser::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/graph_fuser.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/graph_fuser.h new file mode 100644 index 0000000000000000000000000000000000000000..d0a802e2734017b7f2b4a058c1b34f3fabb52ad3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/graph_fuser.h @@ -0,0 +1,47 @@ +#pragma once + +#include +#include + +namespace torch::jit::fuser::onednn { + +struct WorkBlock : public std::pair { + using pair::pair; + + Node* begin() { + return this->first; + } + Node* end() { + return this->second; + } +}; + +class GraphRewriter { + public: + GraphRewriter(Block* block, std::shared_ptr graph, AliasDb& aliasDb) + : block_(block), + graph_(std::move(graph)), + aliasDb_(aliasDb), + llgaHelper_(graph_) {} + + void cleanupSubgraphs(); + void buildupSubgraphs(); + + private: + Block* block_; + std::shared_ptr graph_; + AliasDb& aliasDb_; + LlgaGraphHelper llgaHelper_; + std::vector buildWorkBlocks(); + std::pair scanNode( + Node* consumer, + graph_node_list::iterator workblock_begin); + std::optional tryMerge(Node* consumer, Node* producer); +}; + +// This pass creates the subgraphs for oneDNN Graph Fusion Nodes. +// Its code-structure has been vastly inspired from +// torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +void CreateLlgaSubgraphs(std::shared_ptr& graph); + +} // namespace torch::jit::fuser::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/graph_helper.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/graph_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..bb817092877310d391ad934650a52627df8c9df3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/graph_helper.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::jit::fuser::onednn { + +#define STRIDED_LAYOUT 0 +#define OPAQUE_LAYOUT 1 + +struct OpPartitionMap { + void add(uint64_t opId, uint64_t partitionId) { + opmap_[opId] = partitionId; + } + void add(Node* n, uint64_t partitionId) { + add(Operator::getId(n), partitionId); + } + bool has(uint64_t opId) { + return opmap_.count(opId) > 0; + } + bool has(Node* n) { + return has(Operator::getId(n)); + } + uint64_t get(uint64_t opId) { + return opmap_[opId]; + } + uint64_t get(Node* n) { + auto opId = Operator::getId(n); + TORCH_CHECK( + has(opId), + "Node ", + n->kind().toQualString(), + " does not belong to any LLGA partition"); + return get(opId); + } + + private: + std::unordered_map opmap_; +}; + +class LlgaGraphHelper { + public: + LlgaGraphHelper( + const std::shared_ptr& graph, + dnnl::graph::partition::policy policy = + dnnl::graph::partition::policy::fusion); + + bool shouldMerge(Node* toMerge, Node* subgraph); + + bool shouldConsiderForMerge(Node* node); + + bool checkForSingleOpPartition(Node* node); + + Node* createSingletonSubgraph(Node* n, AliasDb& db); + + void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode, AliasDb& db); + + void unmergeIfAnyNodeIsMissing(Node* subgraphNode); + + static bool isLlgaSubgraph(const Node* node); + + Operator makeEltwiseOp(Node* node, dnnl::graph::op::kind kind); + + Operator makeBinaryOp(Node* node, dnnl::graph::op::kind kind); + + std::vector getPartitions() const; + + std::map getTensorIdToValue() const; + + Operator createOperator(Node* node); + + private: + size_t countSupportedOps(const std::shared_ptr& graph) const; + std::unique_ptr dnnl_graph_ = nullptr; + std::unique_ptr aliasDb_ = nullptr; + OpPartitionMap opToOwningPartition_; + std::vector partitions_; + std::map + tensorIdToValue_; // map from tensorId to torch::jit::Value +}; + +class LlgaNodeWrapper { + public: + LlgaNodeWrapper(const Node* node); + + void setOpaqueLayout(size_t offset); + + bool useOpaqueLayout(size_t offset) const; + + friend class LlgaGraphHelper; + + private: + Node* n; +}; + +} // namespace torch::jit::fuser::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/guard_shape.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/guard_shape.h new file mode 100644 index 0000000000000000000000000000000000000000..227aa35d10a98e043bcd5ca2ef8a3a3042aef892 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/guard_shape.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace torch::jit::fuser::onednn { + +void prepareFusionGroupAndGuardOutputs(Block* block); + +} // namespace torch::jit::fuser::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/interface.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/interface.h new file mode 100644 index 0000000000000000000000000000000000000000..4fd940816308c0a020399fb6c9e7be02d6637b30 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/interface.h @@ -0,0 +1,58 @@ +#pragma once +#include +#include +#include + +namespace torch::jit { +namespace fuser::onednn { + +static std::atomic onednn_enabled{false}; + +static std::atomic& getLlgaEnabled() { + return onednn_enabled; +} + +C10_EXPORT void fuseGraph(std::shared_ptr& g); + +} // namespace fuser::onednn + +struct C10_EXPORT RegisterLlgaFuseGraph + : public PassManager { + static bool setEnabled(bool enabled) { + TORCH_CHECK( + AT_MKLDNN_ENABLED(), + "Running oneDNN Graph fuser is only supported with MKLDNN builds."); + bool oldState = fuser::onednn::getLlgaEnabled(); + fuser::onednn::getLlgaEnabled() = enabled; + if (enabled) { + registerPass(fuser::onednn::fuseGraph); + } else { + clearPass(); + } + return oldState; + } + + static bool isEnabled() { + return fuser::onednn::getLlgaEnabled(); + } + + // override PassManager::registerPass to register pre-pass + static bool registerPass(GraphPass p) { + if (!isRegistered()) { + passID(registerPrePass(std::move(p)), true); + isRegistered(true); + return false; + } + return true; + } + + // override PassManager::clearPass to clear pre-pass + static void clearPass() { + if (isRegistered()) { + clearPrePass(passID()); + isRegistered(true); + } + } +}; + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/kernel.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..cf24190d9aac4fd4e98443c18d3a6a051d007ac7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/kernel.h @@ -0,0 +1,89 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include + +namespace torch::jit::fuser::onednn { + +using ArgSpec = LlgaTensorDesc; +using ArgSpecs = std::vector; +using RunArg = dnnl::graph::tensor; +using RunArgs = std::vector; +using TensorArgs = std::vector; + +class LlgaKernel { + public: + explicit LlgaKernel(const Node* fusionNode); + + void run(Stack& stack); + + void initialize(const TensorArgs& inputs); + + const std::string& debugName() const { + return debugName_; + } + + private: + bool useOpaqueLayout(size_t offset) const; + + // PyTorch copy constants inside the subgraph instead of referencing them. + // Constants inputs to the partition are no longer in the graph->inputs(). + // Need use the tid retrieved from the partition to find the missing + // constant inputs. + void initializeConstantInputs(); + + ArgSpecs initializeInputSpecs(const TensorArgs& inputs); + + ArgSpecs initializeOutputSpecs() const; + + dnnl::graph::compiled_partition compile( + const dnnl::graph::partition& partition); + + std::map initializeTensorIdToOccurence() const; + + std::tuple prepareRunArgs( + const TensorArgs& inputs, + TensorArgs& outputs) const; + + static std::string genDebugName() { + static size_t debugId = 0; + return "LlgaPartition_" + std::to_string(debugId++); + } + + static dnnl::graph::logical_tensor toLogicalTensor(const ArgSpec& s) { + return s.logical_tensor(); + } + + at::Device device_ = at::kCPU; + const Node* fusionNode_; + std::shared_ptr graph_; + int64_t nGraphInputs_ = 0; // number of inputs to graph_ on the IR + int64_t nOutputs_ = 0; + std::map tensorIdToValue_; + std::vector runArgsIdx_; + dnnl::graph::partition partition_; + // nPartitionInputs_ is the actual number of inputs to partition_ of graph_ + // needed by the backend. + // nPartitionInputs_ = nGraphInputs_ + constantInputs_.size() since Constant + // inputs are copied to the inside of the subgraph + int64_t nPartitionInputs_; + dnnl::graph::compiled_partition compilation_; + std::set initializedInputIds_; + std::vector constantValues_; + TensorArgs constantInputs_; + ArgSpecs inputSpecs_; + ArgSpecs outputSpecs_; + std::vector constantLogicalTensors_; + std::string debugName_; + c10::once_flag initialized_flag; + bool is_initialized_ = false; +}; + +} // namespace torch::jit::fuser::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/layout_propagation.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/layout_propagation.h new file mode 100644 index 0000000000000000000000000000000000000000..6af79ca78796a5e4e308bc8116d9c2cc37f22808 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/layout_propagation.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace torch::jit::fuser::onednn { + +void PropagateLayout(const std::shared_ptr& graph); + +} // namespace torch::jit::fuser::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/operator.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/operator.h new file mode 100644 index 0000000000000000000000000000000000000000..1a40c4438b4d8886a8a946bc36c2d7d3dc5a4161 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/operator.h @@ -0,0 +1,146 @@ +#pragma once + +#include +#include +#include + +namespace torch::jit::fuser::onednn { + +class Operator { + public: + Operator(const Node* node, dnnl::graph::op::kind kind) + : n(node), o(getId(node), kind, node->kind().toQualString()), k(kind) {} + + // Returns output index if the Value is a graph output. + // Otherwise returns -1 + int32_t graphOutputIdx(Value* v) { + int32_t i = 0; + for (const Value* output : v->owningGraph()->outputs()) { + if (v == output) { + return i; + } + i++; + } + return -1; + } + + Operator& setInputValue(Value* v) { + if (v->mustNotBeNone()) { + if (v->type()->kind() == c10::TensorType::Kind) { + o.add_input(createLogicalTensor(v)); + } + } + return *this; + } + + Operator& setInput(size_t offset) { + return setInputValue(n->input(offset)); + } + + template + Operator& setInput(size_t offset, Ts... other) { + setInput(offset); + return setInput(other...); + } + + Operator& setOutputValue(Value* v) { + if (v->mustNotBeNone()) { + o.add_output(createLogicalTensor(v)); + } + return *this; + } + + // setOutputValue & setOutput require a pointer to the LLGA graph, as output + // logical tensors that are graph outputs should be connected to an End LLGA + // op. A value of NULL can be provided for the graph pointer in order to + // maintain the legacy functionality of this function. + Operator& setOutputValue(Value* v, std::unique_ptr& g) { + if (v->mustNotBeNone()) { + auto output_tensor = createLogicalTensor(v); + o.add_output(output_tensor); + if (g) { + int32_t outputIndex = graphOutputIdx(v); + if (outputIndex != -1) { + dnnl::graph::op newEndNode( + LONG_MAX - outputIndex, + dnnl::graph::op::kind::End, + "EndNodeForGraphOutput"); + newEndNode.add_input(output_tensor); + g->add_op(newEndNode); + } + } + } + return *this; + } + + Operator& setOutput(std::unique_ptr& g, size_t offset) { + return setOutputValue(n->output(offset), g); + } + + Operator& setOutput(size_t offset) { + return setOutputValue(n->output(offset)); + } + + template + Operator& setOutput( + std::unique_ptr& g, + size_t offset, + Ts... other) { + setOutput(g, offset); + return setOutput(g, other...); + } + + template + Operator& setAttr(dnnl::graph::op::attr name, Attr&& attr) { + o.set_attr(name, std::forward(attr)); + return *this; + } + + template + Operator& setAttr(dnnl::graph::op::attr name, const F& fn, size_t offset) { + return setAttr(name, fn(n, offset)); + } + + static float ScalarToFloat(const Node* node, size_t offset) { + return toIValue(node->input(offset))->toScalar().to(); + } + + static std::vector Ints(const Node* node, size_t offset) { + return toIValue(node->input(offset))->toIntVector(); + } + + static int64_t Int(const Node* node, size_t offset) { + return toIValue(node->input(offset))->toInt(); + } + + static float Float(const Node* node, size_t offset) { + return static_cast(toIValue(node->input(offset))->toDouble()); + } + + static bool Bool(const Node* node, size_t offset) { + return toIValue(node->input(offset))->toBool(); + } + + static uint64_t getId(const Node* node) { + return reinterpret_cast(node); // cast node address as op id + } + + dnnl::graph::op::kind kind() const { + return k; + } + + dnnl::graph::op llgaOp() const { + return o; + } + + private: + dnnl::graph::logical_tensor createLogicalTensor(Value* value) const { + return LlgaTensorDesc(value).logical_tensor(); + } + + const Node* n; + dnnl::graph::op o; + dnnl::graph::op::kind k; +}; + +} // namespace torch::jit::fuser::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/prepare_binary.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/prepare_binary.h new file mode 100644 index 0000000000000000000000000000000000000000..beb66d8822b9d7445ee97072b2d377cf524bc52c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/codegen/onednn/prepare_binary.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace torch::jit::fuser::onednn { + +// Prepare binary ops for LLGA +// +// The pass does the following: +// +// - Convert scalar input of aten::add and aten::mul into Float tensor with +// dimension [1] +// +// - Decompose fused add into aten::mul + aten::add when alpha != 1.0 +// +// - Eliminate identity add/mul, i.e., tensor + 0, tensor * 1 +// +void PrepareBinaryForLLGA(const std::shared_ptr& graph); + +} // namespace torch::jit::fuser::onednn diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/cuda/cuda.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/cuda/cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..dc6d0613976f96cff11cac9ae6d09b261c989bf0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/cuda/cuda.h @@ -0,0 +1,179 @@ +#include +#include +#include +#include + +namespace torch::jit { + +class CUDAEvent; +// This class is a wrapper around c10::cuda::CUDAStream. +// It is needed because TorchBind does not support all of the argument types +// for c10::cuda::CUDAStream. For more details, please refer to +// c10/cuda/CUDAStream.h. +class CUDAStream final : public CustomClassHolder { + public: + CUDAStream( + std::optional device = std::nullopt, + int64_t priority = 0) { + c10::DeviceIndex device_index = + device.has_value() ? device->index() : c10::cuda::current_device(); + stream_ = std::make_unique( + c10::cuda::getStreamFromPool(static_cast(priority), device_index)); + } + + CUDAStream(c10::cuda::CUDAStream s) { + stream_ = std::make_unique(s); + } + + bool query() { + return stream_->query(); + } + + c10::intrusive_ptr recordEvent( + c10::intrusive_ptr event); + + void synchronize() { + stream_->synchronize(); + } + + void waitEvent(const c10::intrusive_ptr& event); + + void waitStream(const c10::intrusive_ptr& stream); + + /// Get the CUDA device index that this stream is associated with. + int64_t device_index() const { + return stream_->device_index(); + } + + /// Get the full Device that this stream is associated with. The Device + /// is guaranteed to be a CUDA device. + c10::Device device() const { + return stream_->device(); + } + + /// Return the stream ID corresponding to this particular stream. + int64_t id() const { + return stream_->id(); + } + + private: + std::unique_ptr stream_; + friend class CUDAEvent; +}; + +// This class is a wrapper around at::cuda::CUDAStream. +// It is needed because TorchBind does not support all of the argument types +// for at::cuda::CUDAEvent. For more details, please refer to +// aten/src/ATen/cuda/CUDAEvent.h. +class CUDAEvent final : public CustomClassHolder { + public: + CUDAEvent( + bool enable_timing = false, + bool blocking = false, + bool interprocess = false) { + int flags = cudaEventDisableTiming; + if (enable_timing) { + flags = cudaEventDefault; + } + if (blocking) { + flags |= cudaEventBlockingSync; + } + if (interprocess) { + TORCH_CHECK(!enable_timing); + flags |= cudaEventInterprocess; + } + + event_ = std::make_unique(flags); + } + + double elapsedTime(const c10::intrusive_ptr& end) { + return event_->elapsed_time(*end->event_); + } + + std::string ipcHandle() { + cudaIpcEventHandle_t handle{}; + event_->ipc_handle(&handle); + std::string str_handle((const char*)&handle, sizeof(handle)); + return str_handle; + } + + bool query() { + return event_->query(); + } + + void record(const c10::intrusive_ptr& stream); + + void synchronize() { + event_->synchronize(); + } + void wait(const c10::intrusive_ptr& stream); + + private: + void recordInternal(CUDAStream* stream); + std::unique_ptr event_; + + friend class CUDAStream; +}; + +inline c10::intrusive_ptr CUDAStream::recordEvent( + c10::intrusive_ptr event) { + if (!event) { + event = c10::make_intrusive(); + } + + event->recordInternal(this); + return event; +} + +inline void CUDAStream::waitEvent(const c10::intrusive_ptr& event) { + event->event_->block(*stream_); +} + +inline void CUDAStream::waitStream( + const c10::intrusive_ptr& stream) { + auto ev = c10::make_intrusive(); + stream->recordEvent(ev); + waitEvent(ev); +} + +inline void CUDAEvent::record(const c10::intrusive_ptr& stream) { + event_->record(*stream->stream_); +} + +inline void CUDAEvent::recordInternal(CUDAStream* stream) { + event_->record(*stream->stream_); +} + +inline void CUDAEvent::wait(const c10::intrusive_ptr& stream) { + event_->block(*stream->stream_); +} + +TORCH_LIBRARY(cuda, m) { + auto stream_class = m.class_("Stream").def( + torch::init, int64_t>(), + "", + {torch::arg("device") = std::nullopt, torch::arg("priority") = 0}); + auto event_class = m.class_("Event").def( + torch::init(), + "", + {torch::arg("enable_timing") = false, + torch::arg("blocking") = false, + torch::arg("interprocess") = false}); + + stream_class.def("query", &CUDAStream::query) + .def("record_event", &CUDAStream::recordEvent) + .def("synchronize", &CUDAStream::synchronize) + .def("wait_event", &CUDAStream::waitEvent) + .def("wait_stream", &CUDAStream::waitStream) + .def("device_index", &CUDAStream::device_index) + .def_property("device", &CUDAStream::device) + .def("id", &CUDAStream::id); + + event_class.def("elapsed_time", &CUDAEvent::elapsedTime) + .def("query", &CUDAEvent::query) + .def("record", &CUDAEvent::record) + .def("synchronize", &CUDAEvent::synchronize) + .def("wait", &CUDAEvent::wait); +} + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/builtin_functions.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/builtin_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..27e190a78a5a8eb6c82f9a9807b203e9cdf33a60 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/builtin_functions.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include + +namespace torch::jit { + +TORCH_API const std::vector& getAllBuiltinFunctionsFor(Symbol name); +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/canonicalize_modified_loop.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/canonicalize_modified_loop.h new file mode 100644 index 0000000000000000000000000000000000000000..ce78a21689d7be1b7efb44deacddfd17fcb65f6f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/canonicalize_modified_loop.h @@ -0,0 +1,14 @@ +#pragma once +#include + +#include + +namespace torch::jit { + +struct Graph; + +// Transforms loops so that they can be represented as python +// for or while loops +TORCH_API void CanonicalizeModifiedLoops(std::shared_ptr& graph); + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/concrete_module_type.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/concrete_module_type.h new file mode 100644 index 0000000000000000000000000000000000000000..f756eda9b4077463730b8b44f6595e4538f86b28 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/concrete_module_type.h @@ -0,0 +1,239 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace torch::jit { + +enum class IterableModuleKind { NONE, LIST, DICT, PARAMLIST, PARAMDICT }; +class ConcreteModuleType; + +// You can think of an nn.Module as a template that corresponds to a family of +// JIT types. The template "arguments" are things like the constant values. +// e.g. +// class M(nn.Module): +// __constants__ = ["const"] +// ... +// +// Is similar to writing the following in C++: +// +// template +// class M { +// ... +// } +// +// We need to consider each different member of the type family a different JIT +// type because, e.g. different constant values lead to different versions of +// the same method. +// +// ConcreteModuleType corresponds to a single member of the type family, with +// all template arguments fully specified. Two Modules that share a +// ConcreteModuleType can share a JIT type, and vice versa. +// +// Why not just use a JIT type to represent concrete types? Because constants, +// function attributes, etc. are currently not representable in the type system, +// so this acts a non-first-class way of tracking concrete types. +// +// ConcreteModuleType is also the source of truth for servicing all +// ModuleValue::attr calls. This is so we can guarantee that if two Module's +// share a JIT type (and thus a ConcreteModuleType), then they behave the same +// way when you access attributes on them. + +// ConcreteModuleType has two phases. +// 1. Creation: First we build it up, during the ScriptModule conversion +// process. This is represented by ConcreteModuleTypeBuilder. +// ...then the converter calls ConcreteModuleTypeBuilder::build(), producing +// a +// ConcreteModuleType ready for querying. +// 2. Querying: We use ConcreteModuleType as a source of truth for +// ModuleValue::attr calls during method compilation. + +// Represents a concrete type during in the process for construction. We use +// this to decide whether we can share types between modules. +class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder { + public: + explicit ConcreteModuleTypeBuilder(py::object pyClass) { + TORCH_INTERNAL_ASSERT(pyClass); + pyClass_ = std::move(pyClass); + } + + void addConstant(std::string name, py::object value); + void addConstant(std::string name, IValue value); + void addAttribute( + std::string name, + const TypePtr& type, + bool isParameter, + bool isBuffer); + void addFunctionAttribute( + std::string name, + const TypePtr& type, + py::object pyFunction); + + void addModule(std::string name, std::shared_ptr meta); + + void addForwardHook(py::object hook); + void addForwardPreHook(py::object pre_hook); + + void addOverload( + std::string methodName, + std::vector overloadedMethodNames); + void addBuiltinFunction(std::string name, const std::string& symbol_name); + void addFailedAttribute(std::string name, std::string failureReason); + void addIgnoredAttribute(std::string name); + void setIterableModuleKind(IterableModuleKind kind); + + // If a ConcreteModuleType is poisoned, it will never compare equal to any + // other concrete type + void setPoisoned(); + + std::shared_ptr build() const { + return std::make_shared(*this); + } + + // This determines whether two modules can share a type. The container structs + // used by ConcreteModuleType have been defined such that operator== + // implements a meaningful comparison in that context. + bool equals(const ConcreteModuleTypeBuilder& other) const; + + struct FunctionAttribute { + FunctionTypePtr function_; + py::object pyFunction_; + + friend bool operator==( + const FunctionAttribute& lhs, + const FunctionAttribute& rhs) { + // Functions are not first class, so we can't do type comparison like a + // regular attribute. So we do a pointer equality check on the actual + // Python function object. + return lhs.pyFunction_.is(rhs.pyFunction_); + } + }; + + struct Attribute { + Attribute(TypePtr type, bool isParam, bool isBuffer) + : type_(std::move(type)), isParam_(isParam), isBuffer_(isBuffer) {} + + friend bool operator==(const Attribute& lhs, const Attribute& rhs) { + return *(lhs.type_) == *(rhs.type_) && lhs.isParam_ == rhs.isParam_; + } + TypePtr type_; + bool isParam_; + bool isBuffer_; + }; + + struct ModuleInfo { + ModuleInfo(std::string name, std::shared_ptr meta) + : name_(std::move(name)), meta_(std::move(meta)) {} + + friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs); + + std::string name_; + std::shared_ptr meta_; + }; + + private: + ConcreteModuleTypeBuilder() = default; + ClassTypePtr createTypeFromThis() const; + + // If true, this type will never compare equally to anything else. This is + // used if we want to ensure that this type is not shared (for example, if it + // came from a traced module) + bool isPoisoned_ = false; + + // The value of any constants defined by the module. + std::unordered_map constants_; + // The types of any attributes + OrderedDict attributes_; + // Overloads, in the same format as `__overloads__` in Python + std::unordered_map> overloads_; + // Any attributes we failed to convert to TorchScript, along with a hint as to + // why + std::unordered_map failedAttributes_; + // Any attributes that were marked as ignored. They cannot be used in + // TorchScript but can still be used in ignored function in Python. + std::unordered_set ignoredAttributes_; + // Any function attributes. These are special right now because functions are + // not first-class in the type system. + std::unordered_map functionAttributes_; + // Function attributes that are calls to builtin functions. These get + // de-sugared directly into the corresponding aten:: call. The map is + // attribute name -> aten symbol name + std::unordered_map builtinFunctions_; + // The concrete types of any submodules + std::vector modules_; + // Hooks to be called before/after forward when the module + // is called directly. Used to ensure modules have different types + // when they have different python hooks + // Actual hooks are added to ClassType directly during compilation + std::vector forwardHooks_; + std::vector forwardPreHooks_; + + // If something is a ModuleDict/ModuleList, it means: + // 1. The order of the submodules matters for comparing the type + // 2. The compiler is allowed to treat it like a dict/tuple + IterableModuleKind iterableModuleKind_ = IterableModuleKind::NONE; + + // The original `nn.Module` class that we derived this ScriptModule from. + py::object pyClass_; + + // NOTE: If you ever add any more state to this struct, you need to make sure + // operator== still makes sense! + friend ConcreteModuleType; +}; + +// Represents a finalized concrete type, used to service ModuleValue::attr calls +// during method compilation. +class VISIBILITY_HIDDEN ConcreteModuleType { + public: + explicit ConcreteModuleType(ConcreteModuleTypeBuilder data); + + static std::shared_ptr fromJitType(TypePtr type); + + TypePtr getJitType() const; + std::optional getPyClass() const; + IterableModuleKind getIterableModuleKind() const; + std::optional> findOverloads( + const std::string& name) const; + std::optional findFunctionAttribute(const std::string& name) const; + std::optional findBuiltinFunction(const std::string& name) const; + std::shared_ptr findSubmoduleConcreteType( + const std::string& name) const; + std::optional findFailedAttribute(const std::string& name) const; + bool isIgnoredAttribute(const std::string& name) const; + + // These getters are only here to return things as types that can be + // automatically converted by pybind. + std::unordered_map getConstantsPy() const; + std::unordered_map> getAttributesPy() + const; + std::vector>> + getModulesPy() const; + + bool equals(const ConcreteModuleType& other) const { + if (jitType_ == other.jitType_) { + // If the computed types are the same, these modules can (obviously) share + // a type. + return true; + } + + return data_.equals(other.data_); + } + bool equals(const ConcreteModuleTypeBuilder& other) const { + return data_.equals(other); + } + + void dump() const; + + private: + ConcreteModuleType() = default; + + // The JIT type derived from this ConcreteModuleType. + ConcreteModuleTypeBuilder data_; + TypePtr jitType_; +}; + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/convert_to_ssa.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/convert_to_ssa.h new file mode 100644 index 0000000000000000000000000000000000000000..9ea8bc8cb3819fba75941d4b7f4224436b55aeaf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/convert_to_ssa.h @@ -0,0 +1,14 @@ +#pragma once +#include +#include +#include + +#include +#include + +namespace torch::jit { + +// Convert a graph with Loads & Stores into SSA form +TORCH_API void ConvertToSSA(std::shared_ptr& graph); + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/edit_distance.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/edit_distance.h new file mode 100644 index 0000000000000000000000000000000000000000..761e7ff50f022210f43ce9c32b1121c99352a797 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/edit_distance.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include + +namespace torch::jit { + +TORCH_API size_t ComputeEditDistance( + const char* word1, + const char* word2, + size_t maxEditDistance); + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/error_report.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/error_report.h new file mode 100644 index 0000000000000000000000000000000000000000..635dd35468e3b3c83d2fe993868973fb8297c6d9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/error_report.h @@ -0,0 +1,51 @@ +#pragma once + +#include + +namespace torch::jit { + +struct Call { + std::string fn_name; + SourceRange caller_range; +}; + +struct TORCH_API ErrorReport : public std::exception { + ErrorReport(const ErrorReport& e); + + explicit ErrorReport(const SourceRange& r); + explicit ErrorReport(const TreeRef& tree) : ErrorReport(tree->range()) {} + explicit ErrorReport(const Token& tok) : ErrorReport(tok.range) {} + + const char* what() const noexcept override; + + struct TORCH_API CallStack { + // These functions are used to report why a function was being compiled + // (i.e. what was the call stack of user functions at compilation time that + // led to this error) + CallStack(const std::string& name, const SourceRange& range); + ~CallStack(); + + // Change the range that is relevant for the current function (i.e. after + // each successful expression compilation, change it to the next expression) + static void update_pending_range(const SourceRange& range); + }; + + static std::string current_call_stack(); + + private: + template + friend const ErrorReport& operator<<(const ErrorReport& e, const T& t); + + mutable std::stringstream ss; + OwnedSourceRange context; + mutable std::string the_message; + std::vector error_stack; +}; + +template +const ErrorReport& operator<<(const ErrorReport& e, const T& t) { + e.ss << t; + return e; +} + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/exit_transforms.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/exit_transforms.h new file mode 100644 index 0000000000000000000000000000000000000000..94a983ce388b726433c7474028aa079edfa29fee --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/exit_transforms.h @@ -0,0 +1,10 @@ +#pragma once + +#include +#include + +namespace torch::jit { + +TORCH_API void TransformExits(std::shared_ptr& graph); + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/function_schema_parser.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/function_schema_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..c1a560181f888c83802f81af14e40dbdff55e0ce --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/function_schema_parser.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::jit { + +// allow_typevars: If true, we assume that lowercase types that we don't +// understand are type variables. This is only needed for TorchScript (and not +// not needed for custom ops). +// If false, we disallow typevars, except in certain cases for BC reason (i.e. +// your op is in the aten or prim namespace). +TORCH_API std::variant parseSchemaOrName( + const std::string& schemaOrName, + bool allow_typevars = true); +TORCH_API c10::FunctionSchema parseSchema( + const std::string& schema, + bool allow_typevars = true); +TORCH_API c10::OperatorName parseName(const std::string& name); + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/inline_loop_condition.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/inline_loop_condition.h new file mode 100644 index 0000000000000000000000000000000000000000..74ba37411a9ae6c7405cf74c9fe449080cca3a2b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/inline_loop_condition.h @@ -0,0 +1,14 @@ +#pragma once +#include +#include +#include + +#include +#include + +namespace torch::jit { + +TORCH_API void InlineLoopCondition(std::shared_ptr& graph); +TORCH_API void InlineBlockBeforeNode(Node* before_node, Block* block); + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/ir_emitter.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/ir_emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..a4aee2b7e281ef94c3383b783db370d19ecd0984 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/ir_emitter.h @@ -0,0 +1,19 @@ +#pragma once +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace torch::jit { + +TORCH_API void runCleanupPasses(std::shared_ptr& to_clean); + +TORCH_API bool meaningfulName(const std::string& name); + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/lexer.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/lexer.h new file mode 100644 index 0000000000000000000000000000000000000000..0faf6ff24da45295ae61d78dfaf387b4046b4971 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/lexer.h @@ -0,0 +1,567 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch::jit { + +// single character tokens are just the character itself '+' +// multi-character tokens need an entry here +// if the third entry is not the empty string, it is used +// in the lexer to match this token. + +// These kinds are also used in Tree.h as the kind of the AST node. +// Some kinds TK_APPLY, TK_LIST are only used in the AST and are not seen in the +// lexer. + +#define TC_FORALL_TOKEN_KINDS(_) \ + _(TK_EOF, "eof", "") \ + _(TK_WHITESPACE, "whitespace", "") \ + _(TK_WHITESPACE_EOF, "whitespace_eof", "") \ + _(TK_NUMBER, "number", "") \ + _(TK_NEWLINE, "newline", "") \ + _(TK_INDENT, "indent", "") \ + _(TK_DEDENT, "dedent", "") \ + _(TK_DEF, "def", "def") \ + _(TK_EQUIVALENT, "equivalent", "<=>") \ + _(TK_IDENT, "ident", "") \ + _(TK_STRING, "string", "") \ + _(TK_STRINGLITERAL, "string_literal", "") \ + _(TK_CONST, "const", "") \ + _(TK_LIST, "list", "") \ + _(TK_DICT, "dict", "") \ + _(TK_OPTION, "option", "") \ + _(TK_APPLY, "apply", "") \ + _(TK_COMPREHENSION, "comprehension", "") \ + _(TK_RANGE_CONSTRAINT, "range_constraint", "") \ + _(TK_PARAM, "param", "") \ + _(TK_INFERRED, "inferred", "") \ + _(TK_ACCESS, "access", "") \ + _(TK_ASSIGN, "assign", "") \ + _(TK_AUG_ASSIGN, "aug_assign", "") \ + _(TK_ATTRIBUTE, "attribute", "") \ + _(TK_IF, "if", "if") \ + _(TK_ELSE, "else", "else") \ + _(TK_ELIF, "elif", "elif") \ + _(TK_WHILE, "while", "while") \ + _(TK_EXPR_STMT, "expression statement", "") \ + _(TK_RETURN, "return", "return") \ + _(TK_IS, "is", "is") \ + _(TK_ISNOT, "is not", "is not") \ + _(TK_NE, "ne", "!=") \ + _(TK_EQ, "eq", "==") \ + _(TK_LE, "le", "<=") \ + _(TK_GE, "ge", ">=") \ + _(TK_FLOOR_DIV, "floordiv", "//") \ + _(TK_IF_EXPR, "if", "") \ + _(TK_TRUE, "True", "True") \ + _(TK_FALSE, "False", "False") \ + _(TK_NONE, "None", "None") \ + _(TK_AND, "and", "and") \ + _(TK_OR, "or", "or") \ + _(TK_NOT, "not", "not") \ + _(TK_LSHIFT, "<<", "<<") \ + _(TK_RSHIFT, ">>", ">>") \ + _(TK_CAST, "cast", "") \ + _(TK_PLUS_EQ, "+=", "+=") \ + _(TK_MINUS_EQ, "-=", "-=") \ + _(TK_TIMES_EQ, "*=", "*=") \ + _(TK_DIV_EQ, "/=", "/=") \ + _(TK_MOD_EQ, "%=", "%=") \ + _(TK_BIT_OR_EQ, "|=", "|=") \ + _(TK_BIT_AND_EQ, "&=", "&=") \ + _(TK_BIT_XOR_EQ, "^=", "^=") \ + _(TK_LSHIFT_EQ, "<<=", "<<=") \ + _(TK_RSHIFT_EQ, ">>=", ">>=") \ + _(TK_POW_EQ, "**=", "**=") \ + _(TK_GLOBAL, "global", "global") \ + _(TK_BUILT_IN, "built-in", "") \ + _(TK_SUBSCRIPT, "subscript", "") \ + _(TK_VAR, "variable", "") \ + _(TK_NOTHING, "nothing", "") \ + _(TK_DICT_LITERAL, "dict-literal", "") \ + _(TK_LIST_LITERAL, "list-literal", "") \ + _(TK_TUPLE_LITERAL, "tuple-literal", "") \ + _(TK_FOR, "for", "for") \ + _(TK_IN, "in", "in") \ + _(TK_NOTIN, "not in", "not in") \ + _(TK_STARRED, "starred", "") \ + _(TK_UNARY_MINUS, "unary minus", "") \ + _(TK_POW, "pow operator", "**") \ + _(TK_ARROW, "arrow", "->") \ + _(TK_DECL, "decl", "") \ + _(TK_SLICE_EXPR, "slice expr", "") \ + _(TK_TYPE_COMMENT, "type comment", "# type:") \ + _(TK_RAISE, "raise", "raise") \ + _(TK_ASSERT, "assert", "assert") \ + _(TK_DOTS, "dots", "...") \ + _(TK_LIST_COMP, "list comprehension", "") \ + _(TK_DICT_COMP, "dict comprehension", "") \ + _(TK_BREAK, "break", "break") \ + _(TK_CONTINUE, "continue", "continue") \ + _(TK_DELETE, "del", "del") \ + _(TK_PASS, "pass", "pass") \ + _(TK_CLASS_DEF, "class", "class") \ + _(TK_IMPORT, "import", "import") \ + _(TK_WITH, "with", "with") \ + _(TK_WITH_ITEM, "withitem", "") \ + _(TK_AS, "as", "as") \ + _(TK_PROP, "property", "") \ + _(TK_ELLIPSIS, "Ellipsis", "Ellipsis") \ + _(TK_NONE_TYPE, "NoneType", "NoneType") + +enum TokenKind { + // we use characters to represent themselves so skip all valid characters + // before + // assigning enum values to multi-char tokens. + TK_DUMMY_START = 256, +#define DEFINE_TOKEN(tok, _, _2) tok, + TC_FORALL_TOKEN_KINDS(DEFINE_TOKEN) +#undef DEFINE_TOKEN +}; + +TORCH_API std::string kindToString(int kind); +TORCH_API int stringToKind(const std::string& str); + +// nested hash tables that indicate char-by-char what is a valid token. +struct TokenTrie; +using TokenTrieRef = std::unique_ptr; +struct TokenTrie { + TokenTrie() = default; + void insert(const char* str, int tok) { + if (*str == '\0') { + AT_ASSERT(kind == 0); + kind = tok; + return; + } + + for (size_t i = 0, e = child_chars.size(); i < e; ++i) { + if (child_chars[i] == *str) { + child_tries[i]->insert(str + 1, tok); + return; + } + } + + child_chars.emplace_back(*str); + child_tries.emplace_back(std::make_unique()); + child_tries.back()->insert(str + 1, tok); + } + int kind{0}; // 0 == invalid token + + std::vector child_chars; + std::vector child_tries; +}; + +// stuff that is shared against all TC lexers/parsers and is initialized only +// once. +struct TORCH_API SharedParserData { + SharedParserData() : head(new TokenTrie()) { + for (const char* c = valid_single_char_tokens; *c; c++) { + std::string str(1, *c); + head->insert(str.c_str(), *c); + } + +#define ADD_CASE(tok, _, tokstring) \ + if (*(tokstring) != '\0') { \ + head->insert((tokstring), (tok)); \ + } + TC_FORALL_TOKEN_KINDS(ADD_CASE) +#undef ADD_CASE + } + + bool match( + StringCordView::Iterator pos, + bool continuation, // are we inside a scope where newlines don't count + // (e.g. inside parens) + bool whitespace_token, // should we treat whitespace as a token + int* kind, + StringCordView::Iterator* start, + StringCordView::Iterator* end) { + *start = pos; + // skip whitespace + while (pos.has_next() && isblank(*pos)) { + ++pos; + } + + // special handling + if (pos.has_next()) { + if (*pos == '#' && !isTypeComment(pos)) { + // skip comments + while (pos.has_next() && *pos != '\n') + ++pos; + // tail call, handle whitespace and more comments + return match(pos, continuation, whitespace_token, kind, start, end); + } + if (*pos == '\\') { + auto newiter = pos; + ++newiter; + if (newiter.has_next() && *newiter == '\n' && !whitespace_token) { + ++newiter; + return match(newiter, continuation, false, kind, start, end); + } + } + if (*pos == '\n') { + return match(++pos, continuation, !continuation, kind, start, end); + } + } + // we handle white space before EOF because in the case we have something + // like the following where we need to generate the dedent token if foo: + // ... + // else: + // pass + if (whitespace_token) { + *kind = !pos.has_next() ? TK_WHITESPACE_EOF : TK_WHITESPACE; + *end = pos; + return true; + } + if (!pos.has_next()) { + *kind = TK_EOF; + *start = pos; + *end = *start; + return true; + } + // invariant: the next token is not whitespace or newline + *start = pos; + // check for a valid number + size_t len = 0; + if (isNumber(pos.rest_line(), 0, &len)) { + *end = *start; + *end += len; + *kind = TK_NUMBER; + return true; + } + // check for string + if (isString(pos.rest_line(), 0, &len)) { + *kind = TK_STRINGLITERAL; + *end = *start; + *end += len; + return true; + } + + // check for either an ident or a token + // ident tracks whether what we have scanned so far could be an identifier + // matched indicates if we have found any match. + bool matched = false; + bool ident = true; + TokenTrie* cur = head.get(); + // for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr); + // i++) + for (size_t i = 0; pos.has_next() && (ident || cur != nullptr); + ++pos, ++i) { + ident = ident && validIdent(i, *pos); + if (ident) { + matched = true; + *end = pos.next_iter(); + *kind = TK_IDENT; + } + // check for token second, so that e.g. 'max' matches the token TK_MAX + // rather the + // identifier 'max' + if (cur) { + const auto begin_it = cur->child_chars.begin(); + const auto end_it = cur->child_chars.end(); + const auto ch_it = std::find(begin_it, end_it, *pos); + + cur = (ch_it == end_it) ? nullptr + : cur->child_tries[ch_it - begin_it].get(); + + if (cur && cur->kind != 0) { + matched = true; + *end = pos.next_iter(); + *kind = cur->kind; + } + } + } + return matched; + } + + bool isUnary(int kind, int* prec); + bool isBinary(int kind, int* prec); + bool isRightAssociative(int kind) { + switch (kind) { + case '?': + case TK_POW: + case TK_IF: + return true; + default: + return false; + } + } + + private: + bool validIdent(size_t i, char n) { + return isalpha(n) || n == '_' || (i > 0 && isdigit(n)); + } + + // 1. skip whitespace + // 2. handle comment or newline + // + bool isNumber(std::string_view str, size_t start, size_t* len) { + char first = str[start]; + // strtod allows numbers to start with + or - or nan or inf + // http://en.cppreference.com/w/cpp/string/byte/strtof + // but we want only the number part, otherwise 1+3 will turn into two + // adjacent numbers in the lexer + if (first == '-' || first == '+' || isalpha(first)) + return false; + const char* startptr = str.data() + start; + char* endptr = nullptr; + torch::jit::strtod_c(startptr, &endptr); + *len = endptr - startptr; + // check if the number is complex valued + // access is safe because string is assumed to be null terminated + if (endptr != nullptr && *endptr == 'j') { + *len += 1; + } + return *len > 0; + } + + bool isCharCount(char c, std::string_view str, size_t start, int len) { + // count checks from [start, start + len) + return start + len <= str.size() && + std::count(str.begin() + start, str.begin() + start + len, c) == len; + } + + // python concatenates all adjacent strings "a" "b" == "ab" + // strings can be enclosed with 1 or 3 single or double quotes + // if enclosed with 3 quotes newlines are valid + // as elsewhere, backslash and new line should be ignored + bool isString(std::string_view str, size_t start, size_t* len) { + char quote = str[start]; + if (quote != '\"' && quote != '\'') + return false; + int quote_len = isCharCount(quote, str, start, 3) ? 3 : 1; + + // end is now set past the opening quotation marks + size_t end = start + quote_len; + while (end < str.size() && !isCharCount(quote, str, end, quote_len)) { + if (str[end] == '\n' && quote_len != 3) { + return false; + } + // handle escaped characters. advances past escaped quotation marks, + // escaped newlines and escaped backslashes + // multi-char escapes like \x1A are handled fine here because the + // remainder of the escape are valid string characters anyway + if (str[end] == '\\') { + end++; + } + end++; + } + // set length equal to the complete string including quotations + *len = end - start + quote_len; + // if end finished without going past the last character of the string than + // there is a match + return end < str.size(); + } + + bool isblank(int n) { + return isspace(n) && n != '\n'; + } + + bool isTypeComment(StringCordView::Iterator str_iter) { + std::string_view rest_line = str_iter.rest_line(); + const std::string type_string = "# type:"; + if (rest_line.size() < type_string.length()) { + return false; + } + auto match_string = rest_line.substr(0, type_string.size()); + return match_string == type_string; + } + + // Make an exception ignoring comments for type annotation comments + bool isTypeComment(const StringCordView& str, size_t pos) { + const std::string type_string = "# type:"; + if (str.size() < pos + type_string.length()) { + return false; + } + auto match_string = str.substr(pos, type_string.size()); + return match_string == type_string; + } + + TokenTrieRef head; +}; + +TORCH_API SharedParserData& sharedParserData(); + +struct Token { + int kind; + SourceRange range; + Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {} + std::string text() const { + return std::string(range.token_text()); + } + + std::string_view text_view() const { + return range.token_text(); + } + + std::string kindString() const { + return kindToString(kind); + } +}; + +struct Lexer { + explicit Lexer(std::shared_ptr source) + : source(std::move(source)), + + indent_stack(), + next_tokens(), + shared(sharedParserData()) { + auto first_indent = lexRaw(true); + indent_stack.push_back(first_indent.range.size()); + lex(); + } + // Return the current token, and then move to the next one + Token next() { + if (next_tokens.empty()) + reportError("Lexer invariant violated: empty token queue"); + Token r = std::move(next_tokens.front()); + next_tokens.erase(next_tokens.begin()); + if (next_tokens.empty()) { + lex(); + } + return r; + } + // Skip the current token if it matches the given kind + bool nextIf(int kind) { + if (cur().kind != kind) + return false; + next(); + return true; + } + + [[noreturn]] void reportError(const std::string& what) { + reportError(what, cur()); + } + [[noreturn]] void reportError(const std::string& what, const Token& t) { + std::stringstream ss; + ss << what << ":\n"; + t.range.highlight(ss); + throw std::runtime_error(ss.str()); + } + [[noreturn]] void expected(const std::string& what, const Token& t) { + std::stringstream ss; + ss << "expected " << what << " but found '" << t.kindString() + << "' here:\n"; + t.range.highlight(ss); + throw std::runtime_error(ss.str()); + } + [[noreturn]] void expected(const std::string& what) { + expected(what, cur()); + } + // Check that the current token has a given kind, return the current token, + // and advance to the next one. + Token expect(int kind) { + if (cur().kind != kind) { + expected(kindToString(kind)); + } + return next(); + } + Token& lookahead() { + if (next_tokens.size() < 2) { + lex(); + } + return next_tokens[1]; + } + Token& cur() { + return next_tokens.front(); + } + + private: + void lex() { + auto r = lexRaw(); + switch (r.kind) { + case '(': + case '[': + case '{': + nesting++; + break; + case ')': + case ']': + case '}': + nesting--; + break; + case TK_WHITESPACE: + case TK_WHITESPACE_EOF: { + const auto depth = + r.kind == TK_WHITESPACE_EOF ? indent_stack.front() : r.range.size(); + // note: TK_WHITESPACE_EOF is whitespace right before the EOF token + // just like we allow the code to be indented to a particular initial + // indent level, we allow the final indent to be anything and set + // it back to the initial indent level. This allows the code to be + // put into string literals inside code without worrying about final + // whitespace + if (depth > indent_stack.back()) { + indent_stack.push_back(depth); + r.kind = TK_INDENT; + } else if (depth == indent_stack.back()) { + r.kind = TK_NEWLINE; + } else { + next_tokens.emplace_back(TK_NEWLINE, r.range); + while (indent_stack.back() != depth) { + indent_stack.pop_back(); + next_tokens.emplace_back(TK_DEDENT, r.range); + if (indent_stack.empty()) { + reportError("invalid indent level " + std::to_string(depth), r); + } + } + return; // We've already queued the tokens + } + } break; + default: + break; + } + next_tokens.push_back(std::move(r)); + } + Token lexRaw(bool whitespace_token = false) { + AT_ASSERT(source); + if (current == nullptr) { + AT_ASSERT(pos == 0); + current = std::make_unique( + source->text_str().begin()); + } + + StringCordView::Iterator start_iter = *current; + StringCordView::Iterator end_iter = *current; + int kind = 0; + if (!shared.match( + *current, + nesting > 0, + whitespace_token, + &kind, + &start_iter, + &end_iter)) { + expected( + "a valid token", + Token( + **current, + SourceRange(source, start_iter, start_iter.pos() + 1))); + } + + auto t = Token(kind, SourceRange(source, start_iter, end_iter.pos())); + pos = end_iter.pos(); + *current = end_iter; + return t; + } + + std::shared_ptr source; + std::unique_ptr current; + size_t pos{0}; + size_t nesting{0}; // depth of ( [ { nesting... + std::vector indent_stack; // stack of indentation level of blocks + // Invariant: this should always contain at least a single element + std::vector next_tokens; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + SharedParserData& shared; +}; +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/mini_environment.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/mini_environment.h new file mode 100644 index 0000000000000000000000000000000000000000..1b71927ffd594c80abd6f0a9eab7f938723cd9d7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/mini_environment.h @@ -0,0 +1,55 @@ +#pragma once + +#include +#include + +namespace torch::jit { + +// Simple data structure for containing a type T in nested control blocks +// Should only be used after initial compilation where type checking and +// loads and stores are emitted + +template +struct MiniEnvironment { + MiniEnvironment(Block* b, std::shared_ptr next = nullptr) + : next(std::move(next)) {} + + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::shared_ptr> next; + + T findInThisFrame(const std::string& name) { + auto it = table.find(name); + if (it != table.end()) { + return it->second; + } + return nullptr; + } + + T findInAnyFrame(const std::string& name) { + for (auto runner = this; runner; runner = runner->next.get()) { + if (auto r = runner->findInThisFrame(name)) { + return r; + } + } + return nullptr; + } + + void setVar(const std::string& name, T value) { + table[name] = value; + } + + std::vector definedVariables() { + std::vector result; + result.reserve(table.size()); + for (auto& kv : table) { + result.push_back(kv.first); + } + std::sort(result.begin(), result.end()); + return result; + } + + private: + std::unordered_map table; +}; + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/name_mangler.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/name_mangler.h new file mode 100644 index 0000000000000000000000000000000000000000..2f436f91a1f3e64b6304a5bd47f703390863b789 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/name_mangler.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +namespace torch::jit { + +/** + * class NameMangler + * + * Utility to mangle qualified names in order to make them unique. We use this + * in various places where we to de-duplicate qualified names. + */ +class TORCH_API NameMangler { + public: + // Given a qualified name, return a mangled version that is guaranteed to be + // unique with respect to previous/future calls of `mangled()` on this name + // mangler instance. + c10::QualifiedName mangle(const c10::QualifiedName& name); + + private: + size_t mangleIndex_ = 0; +}; + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/parse_string_literal.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/parse_string_literal.h new file mode 100644 index 0000000000000000000000000000000000000000..5139ae9ec790ad96868535c2329d6128754d6f07 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/parse_string_literal.h @@ -0,0 +1,87 @@ +#pragma once +#include +#include +#include + +namespace torch::jit { + +inline bool isCharCount(char c, const std::string& str, size_t start, int len) { + // count checks from [start, start + len) + return start + len <= str.size() && + std::count( + str.begin() + static_cast(start), + str.begin() + static_cast(start + len), + c) == len; +} + +inline std::optional parseOctal(const std::string& str, size_t pos) { + //\xxx where x are 0-7 + if (pos + 3 >= str.size()) + return std::nullopt; + size_t c = 0; + for (size_t i = 1, b = 64; i < 4; ++i, b /= 8) { + auto d = str[pos + i]; + if (d < '0' || d > '7') + return std::nullopt; + c += b * (d - '0'); + } + if (c >= 256) + return std::nullopt; + return c; +} + +inline std::string parseStringLiteral( + const SourceRange& range, + const std::string& str) { + size_t quote_len = isCharCount(str[0], str, 0, 3) ? 3 : 1; + auto ret_str = str.substr(quote_len, str.size() - quote_len * 2); + size_t pos = ret_str.find('\\'); + while (pos != std::string::npos) { + // invariant: pos has to escape a character because it is a valid string + char c = ret_str[pos + 1]; + size_t to_erase = 2; + switch (ret_str[pos + 1]) { + case '\\': + case '\'': + case '\"': + case '\n': + break; + case 'a': + c = '\a'; + break; + case 'b': + c = '\b'; + break; + case 'f': + c = '\f'; + break; + case 'n': + c = '\n'; + break; + case 'v': + c = '\v'; + break; + case 't': + c = '\t'; + break; + case 'x': + throw(ErrorReport(range) << "unsupported hex specifier"); + case 'u': + case 'U': + throw(ErrorReport(range) << "unsupported unicode specifier"); + default: + // octal value in format \nnn, n is [0-7] + if (auto v = parseOctal(ret_str, pos)) { + to_erase = 4; + c = *v; + } else { + throw(ErrorReport(range) << " ill formed octal specifier"); + } + } + ret_str.replace(pos, to_erase, /* num copies */ 1, c); + pos = ret_str.find('\\', pos + 1); + } + return ret_str; +} + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/parser.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/parser.h new file mode 100644 index 0000000000000000000000000000000000000000..7f4d17b0ce8dc0ddfdfca4720e7aa66e6fee8798 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/parser.h @@ -0,0 +1,31 @@ +#pragma once +#include +#include +#include +#include + +namespace torch::jit { + +struct Decl; +struct ParserImpl; +struct Lexer; + +TORCH_API Decl mergeTypesFromTypeComment( + const Decl& decl, + const Decl& type_annotation_decl, + bool is_method); + +struct TORCH_API Parser { + explicit Parser(const std::shared_ptr& src); + TreeRef parseFunction(bool is_method); + TreeRef parseClass(); + Decl parseTypeComment(); + Expr parseExp(); + Lexer& lexer(); + ~Parser(); + + private: + std::unique_ptr pImpl; +}; + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/parser_constants.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/parser_constants.h new file mode 100644 index 0000000000000000000000000000000000000000..fb5cf0d88e1e16142a9de2c9ece8068af66dec1b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/parser_constants.h @@ -0,0 +1,6 @@ +#pragma once + +namespace torch::jit { +static constexpr const char* valid_single_char_tokens = + "+-*/%@()[]:,={}><.?!&^|~"; +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/resolver.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/resolver.h new file mode 100644 index 0000000000000000000000000000000000000000..d5b0f1954c833d10cf778667e01cfed72c9f4fa5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/resolver.h @@ -0,0 +1,66 @@ +#pragma once + +#include +#include +#include + +namespace torch::jit { + +struct Resolver; +using ResolverPtr = std::shared_ptr; + +/** + * class Resolver + * + * Represents an "outer environment" in which we an look up names and return + * a corresponding SugaredValue. This is used during compilation to resolve + * references to names which are not defined internal to the graph. + * + * Example: PythonResolver looks at the enclosing Python scope for `name`. + * + * NOTE: When adding methods, keep this an abstract class (i.e. all new methods + * should be purely virtual). Resist the urge to provide a default + * implementation; you should explicitly think about how each resolver would + * handle the method. + */ +struct Resolver { + virtual ~Resolver() = default; + + // Resolve a given name to a SugaredValue. This takes the method `m` that the + // caller is currently constructing, since we may need to insert nodes into + // the graph to create a value. + virtual std::shared_ptr resolveValue( + const std::string& name, + GraphFunction& m, + const SourceRange& loc) { + return nullptr; + } + + // Resolve `name` to a TypePtr. + virtual TypePtr resolveType(const std::string& name, const SourceRange& loc) { + return nullptr; + } +}; + +// A resolver that only understands "torch.foo()" lookups. +struct NativeResolver : public Resolver { + std::shared_ptr resolveValue( + const std::string& name, + GraphFunction& m, + const SourceRange& loc) override { + if (name == "torch") { + return std::make_shared("aten"); + } + return nullptr; + } + + TypePtr resolveType(const std::string& name, const SourceRange& loc) + override { + return nullptr; + } +}; + +inline std::shared_ptr nativeResolver() { + return std::make_shared(); +} +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/schema_matching.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/schema_matching.h new file mode 100644 index 0000000000000000000000000000000000000000..ddc6f1f22dd118e02f5676b6f5eba6af5664da6a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/schema_matching.h @@ -0,0 +1,68 @@ +#pragma once +#include +#include +#include + +#include + +namespace torch::jit { + +// Try to match a list of inputs and keyword 'attributes' to this +// schema. Return the flat list of positional inputs to the call or +// `std::nullopt` on failure (`failure_messages` contains a good error +// report in this case) + +struct MatchedSchema { + std::vector inputs; + std::vector return_types; + c10::OptNameList return_field_names; + std::string schema_name; +}; + +TORCH_API bool isBlockListedSchema(const FunctionSchema& schema); + +TORCH_API MatchedSchema matchSchema( + const ::c10::FunctionSchema& schema, + const SourceRange& loc, + Graph& graph, + at::ArrayRef args, + at::ArrayRef kwargs, + const std::optional& self = std::nullopt); + +TORCH_API std::pair matchSchemas( + const std::vector& schemas, + const SourceRange& loc, + Graph& graph, + at::ArrayRef args, + at::ArrayRef kwargs, + const std::optional& self = std::nullopt, + bool render_errors = false); + +TORCH_API bool convertibleToList( + const TypePtr& type, + const TypePtr& list_type_); + +TORCH_API std::string getFullSchemaName(const ::c10::FunctionSchema& schema); + +TORCH_API Value* emitBuiltinCall( + const SourceRange& loc, + Graph& graph, + Symbol name, + at::ArrayRef args, + at::ArrayRef kwargs, + const std::optional& self = std::nullopt); + +TORCH_API std::optional findInputWithName( + const std::string& name, + at::ArrayRef kwargs, + bool is_aten = false); + +// applies implicit conversion from value trying to turn it into type +// concrete_type it succeeds if the return_value->isSubtypeOf(concrete_type) +TORCH_API Value* tryConvertToType( + const SourceRange& loc, + Graph& graph, + const TypePtr& concrete_type, + Value* value, + bool allow_conversions); +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/schema_type_parser.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/schema_type_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..ca5a00ecaa3fbd72d9d23d95bd509fc4b2aa70f4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/schema_type_parser.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch::jit { + +using TypePtr = c10::TypePtr; + +struct TORCH_API SchemaTypeParser { + TypePtr parseBaseType(); + std::optional parseAliasAnnotation(); + std::pair> parseType(); + std::tuple> + parseFakeAndRealType(); + std::optional parseTensorDType(const std::string& dtype); + TypePtr parseRefinedTensor(); + + SchemaTypeParser( + Lexer& L, + bool parse_complete_tensor_types, + bool allow_typevars) + : complete_tensor_types(parse_complete_tensor_types), + L(L), + allow_typevars_(allow_typevars) {} + + private: + std::optional tryToParseRequiresGrad(); + std::optional tryToParseDeviceType(); + void parseList( + int begin, + int sep, + int end, + c10::function_ref callback); + + bool complete_tensor_types; + Lexer& L; + size_t next_id = 0; + bool allow_typevars_; +}; +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/script_type_parser.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/script_type_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..205727fe6d6546d6887bc3221ac4215a79b7415e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/script_type_parser.h @@ -0,0 +1,53 @@ +#pragma once +#include +#include +#include +#include + +namespace torch::jit { + +/** + * class ScriptTypeParser + * + * Parses expressions in our typed AST format (TreeView) into types and + * typenames. + */ +class TORCH_API ScriptTypeParser { + public: + explicit ScriptTypeParser() = default; + explicit ScriptTypeParser(ResolverPtr resolver) + : resolver_(std::move(resolver)) {} + + c10::TypePtr parseTypeFromExpr(const Expr& expr) const; + + std::optional> parseBroadcastList( + const Expr& expr) const; + + c10::TypePtr parseType(const std::string& str); + + FunctionSchema parseSchemaFromDef(const Def& def, bool skip_self); + + c10::IValue parseClassConstant(const Assign& assign); + + private: + c10::TypePtr parseTypeFromExprImpl(const Expr& expr) const; + + std::optional parseBaseTypeName(const Expr& expr) const; + at::TypePtr subscriptToType( + const std::string& typeName, + const Subscript& subscript) const; + std::vector evaluateDefaults( + const SourceRange& r, + const std::vector& default_types, + const std::vector& default_exprs); + std::vector parseArgsFromDecl(const Decl& decl, bool skip_self); + + std::vector parseReturnFromDecl(const Decl& decl); + + ResolverPtr resolver_ = nullptr; + + // Need to use `evaluateDefaults` in serialization + friend struct ConstantTableValue; + friend struct SourceImporterImpl; +}; +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/source_range.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/source_range.h new file mode 100644 index 0000000000000000000000000000000000000000..5d2eeb42e7e754e12c0f146a091aafcb36d22063 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/source_range.h @@ -0,0 +1,603 @@ +#pragma once +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace torch::jit { + +class SourceRangeUnpickler; +struct SourceRange; + +// A stringlike class backed by a vector of string_view +// the string represented are logically the concatenation of the string_views +// This has advantage of not needing continues memory. +struct TORCH_API StringCordView { + StringCordView(); + StringCordView(const StringCordView&) = default; + StringCordView(StringCordView&&) noexcept = default; + StringCordView( + std::vector inputs, + std::vector> ownerships); + + StringCordView& operator=(const StringCordView&) = default; + StringCordView& operator=(StringCordView&&) noexcept = default; + + size_t size() const { + return accumulated_sizes_.back(); + } + + size_t find(const std::string& tok, size_t start) const; + size_t find_regex(const std::string& tok, size_t start) const; + StringCordView substr(size_t start, size_t size) const; + + char at(size_t index) const { + return *iter_for_pos(index); + } + char operator[](size_t index) const { + return at(index); + } + + std::string str() const { + std::stringstream ss; + for (auto s : pieces_) { + ss << std::string(s); + } + return ss.str(); + } + + bool operator==(const std::string& rhs) const; + + bool operator==(const StringCordView& rhs) const; + + std::string_view piece(size_t index) const { + return pieces_[index]; + } + + // General-case iterator implementation. + struct IteratorImpl { + IteratorImpl( + const StringCordView* str, + size_t start_line, + size_t start_pos, + size_t size) + : line_(start_line), pos_(start_pos), str_(str), size_(size) {} + explicit IteratorImpl(const StringCordView* str) + : IteratorImpl(str, 0, 0, str->size()) {} + + IteratorImpl() : IteratorImpl(nullptr, 0, 0, 0) {} + + IteratorImpl(const IteratorImpl&) = default; + IteratorImpl(IteratorImpl&&) = default; + IteratorImpl& operator=(const IteratorImpl&) = default; + IteratorImpl& operator=(IteratorImpl&&) = default; + + IteratorImpl& operator++() { + if (size_ == 0) { + return *this; + } + if ((pos_ + 1) < str_->pieces_[line_].size()) { + pos_++; + } else { + line_++; + pos_ = 0; + } + return *this; + } + + IteratorImpl operator++(int) { + IteratorImpl prev(*this); + ++(*this); + return prev; + } + + IteratorImpl next_iter() const { + IteratorImpl next(*this); + ++next; + return next; + } + + IteratorImpl& operator+=(size_t num); + + IteratorImpl operator+(size_t num) const { + IteratorImpl it(*this); + it += num; + return it; + } + + bool operator==(const IteratorImpl& rhs) const { + if (!has_next() && !rhs.has_next()) { + return true; + } + return (str_ == rhs.str_) && (line_ == rhs.line_) && (pos_ == rhs.pos_); + } + + bool operator!=(const IteratorImpl& rhs) const { + return !((*this) == rhs); + } + bool has_next() const { + return size_ > 0 && (line_ < str_->pieces_.size()); + } + + char operator*() const { + TORCH_INTERNAL_ASSERT(line_ < str_->pieces_.size()); + TORCH_INTERNAL_ASSERT(pos_ < str_->pieces_[line_].size()); + return str_->pieces_[line_].at(pos_); + } + + // returns rest of the line of the current iterator + std::string_view rest_line() const { + if (line_ >= str_->pieces_.size()) { + return ""; + } + + std::string_view cur_line = str_->pieces_[line_]; + return cur_line.substr(pos_, std::string::npos); + } + + size_t pos() const { + if (size_ == 0) { + return 0; + } + return str_->accumulated_sizes_[line_] + pos_; + } + + private: + size_t line_; + size_t pos_; + const StringCordView* str_; + size_t size_; + friend struct StringCordView; + }; + + // Either an IteratorImpl, or a simple std::string_view::iterator + // (which is faster) if possible. + struct Iterator { + Iterator() = default; + + Iterator( + const StringCordView* str, + size_t start_line, + size_t start_pos, + size_t size) + : repr_( + str->pieces_.size() == 1 + ? repr_type(FastRepr( + start_line ? str->pieces_[0].end() + : str->pieces_[0].begin() + start_pos, + str)) + : repr_type(IteratorImpl(str, start_line, start_pos, size))) { + } + + Iterator(const StringCordView* str) : Iterator(str, 0, 0, str->size()) {} + + Iterator& operator++() { + if (auto* pit = std::get_if(&repr_)) { + ++(*pit); + } else { + ++fast_repr().it; + } + return *this; + } + + Iterator operator++(int) { + Iterator prev(*this); + ++(*this); + return prev; + } + + Iterator next_iter() const { + Iterator next(*this); + ++next; + return next; + } + + Iterator& operator+=(size_t num) { + if (auto* pit = std::get_if(&repr_)) { + *pit += num; + } else { + fast_repr().it += num; + } + return *this; + } + + Iterator operator+(size_t num) const { + Iterator it(*this); + it += num; + return it; + } + + bool operator==(const Iterator& rhs) const { + return repr_ == rhs.repr_; + } + + bool operator!=(const Iterator& rhs) const { + return repr_ != rhs.repr_; + } + + bool has_next() const { + if (const auto* pit = std::get_if(&repr_)) { + return pit->has_next(); + } else { + return fast_repr().it != fast_repr().str->pieces_[0].end(); + } + } + + char operator*() const { + if (const auto* pit = std::get_if(&repr_)) { + return **pit; + } else { + return *fast_repr().it; + } + } + + std::string_view rest_line() const { + if (const auto* pit = std::get_if(&repr_)) { + return pit->rest_line(); + } else { + // NOTE: std::string_view(it, end) ctor wasn't added until C++20. + const auto fast_repr_end = fast_repr().str->pieces_[0].end(); + if (fast_repr().it != fast_repr_end) { + return std::string_view( + &*fast_repr().it, fast_repr_end - fast_repr().it); + } + return std::string_view(); + } + } + + size_t pos() const { + if (const auto* pit = std::get_if(&repr_)) { + return pit->pos(); + } else { + return fast_repr().it - fast_repr().str->pieces_[0].begin(); + } + } + + private: + // When we have only one entry in pieces_ (importantly, such as + // when called from torch::Library::def during startup), we can + // skip extra complexity and just use string_view::iterator + // directly. + struct FastRepr { + std::string_view::iterator it; + const StringCordView* str; + + FastRepr() : str(nullptr) {} + + explicit FastRepr( + std::string_view::iterator it_, + const StringCordView* str_) + : it(it_), str(str_) {} + + bool operator==(const FastRepr& rhs) const { + return it == rhs.it && str == rhs.str; + } + + bool operator!=(const FastRepr& rhs) const { + return !operator==(rhs); + } + }; + using repr_type = std::variant; + repr_type repr_; + + FastRepr& fast_repr() { + // -Oz refuses to inline std::get. + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(std::holds_alternative(repr_)); + return *std::get_if(&repr_); + } + + const FastRepr& fast_repr() const { + // -Oz refuses to inline std::get. + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(std::holds_alternative(repr_)); + return *std::get_if(&repr_); + } + }; + + Iterator begin() const { + return Iterator(this, 0, 0, size()); + } + Iterator end() const { + return Iterator(this, pieces_.size(), 0, 0); + } + Iterator iter_for_pos(size_t pos) const; + + private: + IteratorImpl begin_impl() const { + return IteratorImpl(this, 0, 0, size()); + } + IteratorImpl end_impl() const { + return IteratorImpl(this, pieces_.size(), 0, 0); + } + IteratorImpl iter_impl_for_pos(size_t pos) const; + std::vector pieces_; + std::vector accumulated_sizes_; + std::vector> owned_strings_; +}; + +// Source represents a code segment. It keeps track of: +// - text_view : the view into text of the code segment +// - filename (optional) : if present, represents the name of the file from +// which the code segment originated. +// - starting_line_no : represents the line in the original file where the +// code segment started. +struct TORCH_API Source { + // Whether or not Source should copy the string passed in the constructor. + enum CopiesString { COPIES_STRING, DONT_COPY }; + + explicit Source( + std::string_view text_view, + std::optional filename = std::nullopt, + size_t starting_line_no = 0, + std::shared_ptr gen_ranges = nullptr, + CopiesString copies_str = COPIES_STRING) + : text_view_(create_text_view(copies_str, text_view)), + filename_(std::move(filename)), + starting_line_no_(starting_line_no), + gen_ranges_(std::move(gen_ranges)) { + calc_line_start_offsets(); + } + + explicit Source( + StringCordView str, + std::optional filename = std::nullopt, + size_t starting_line_no = 0, + std::shared_ptr gen_ranges = nullptr) + : text_view_(std::move(str)), + filename_(std::move(filename)), + starting_line_no_(starting_line_no), + gen_ranges_(std::move(gen_ranges)) { + calc_line_start_offsets(); + } + // Given a line number (within source_), return the byte offset of the + // beginning of that line. + size_t offset_for_line(size_t line) const { + return line_starting_offsets_.at(line); + } + + // Returns number of lines present. + size_t num_lines() const { + return line_starting_offsets_.size(); + } + + // Calculate the line (within the code segment) on which `offset` resides. + size_t lineno_for_offset(size_t offset) const { + auto iter = std::upper_bound( + line_starting_offsets_.begin(), line_starting_offsets_.end(), offset); + return iter - line_starting_offsets_.begin() - 1; + } + + // Calculate the line (within the original source file, if present) on which + // `lineno` resides. + size_t lineno_to_source_lineno(size_t lineno) const { + if (filename_) { + return lineno + starting_line_no_; + } else { + return lineno; + } + } + + StringCordView get_line(size_t lineno) const { + auto start = offset_for_line(lineno); + auto size = (lineno + 1) < num_lines() ? offset_for_line(lineno + 1) - start + : text_view_.size() - start; + return text_view_.substr(start, size); + } + + const StringCordView& text_str() const { + return text_view_; + } + + char char_at(size_t index) const { + return text_view_.at(index); + } + + size_t size() const { + return text_view_.size(); + } + + std::optional& filename() { + return filename_; + } + + size_t starting_line_no() const { + return starting_line_no_; + } + + std::optional findSourceRangeThatGenerated( + const SourceRange& range); + + ~Source() = default; + + private: + void calc_line_start_offsets() { + line_starting_offsets_.clear(); + line_starting_offsets_.push_back(0); + size_t pos = 0; + while ((pos = text_view_.find("\n", pos)) != std::string::npos) { + line_starting_offsets_.push_back(++pos); + } + } + + static StringCordView create_text_view( + CopiesString copies_str, + std::string_view text_view) { + if (copies_str == COPIES_STRING) { + auto allocated_str = + std::make_shared(text_view.data(), text_view.size()); + return StringCordView({*allocated_str}, {allocated_str}); + } else { + return StringCordView({text_view}, {}); + } + } + + StringCordView text_view_; + + std::optional filename_; + // If filename_ is not present, starting_line_no_ is don't care + size_t starting_line_no_; + // Starting offsets for lines into the source. e.g. line 0 starts at + // line_starting_offsets_[0], etc. + std::vector line_starting_offsets_; + + std::shared_ptr gen_ranges_; +}; + +// A SourceRange is a reference to subset of a Source, specified by `start` and +// `end` byte offsets into the source text. +struct TORCH_API SourceRange { + SourceRange(std::shared_ptr source_view, size_t start_, size_t end_) + : source_view_(std::move(source_view)), start_(start_), end_(end_) { + if (source_view_) { + start_iter_ = source_view_->text_str().iter_for_pos(start_); + } + } + + SourceRange() : source_view_(nullptr), start_(0), end_(0) {} + + SourceRange( + std::shared_ptr source_view_, + StringCordView::Iterator start_iter, + size_t end_) + : source_view_(std::move(source_view_)), + start_(start_iter.pos()), + end_(end_), + start_iter_(start_iter) {} + + const std::string_view token_text() const { + size_t size = end() - start(); + return start_iter_.rest_line().substr(0, size); + } + + const StringCordView text() const { + return source_view_->text_str().substr(start(), end() - start()); + } + size_t size() const { + return end() - start(); + } + static const size_t CONTEXT = 3; + void highlight(std::ostream& out) const; + + // Customizable version of 'highlight' method. + void print_with_context( + std::ostream& out, + size_t context, + bool highlight, + const std::string& funcname) const; + + const std::shared_ptr& source() const { + return source_view_; + } + size_t start() const { + return start_; + } + size_t end() const { + return end_; + } + std::string str() const { + std::stringstream ss; + highlight(ss); + return ss.str(); + } + + std::optional> file_line_col() const { + if (!source_view_ || !source()->filename()) { + return std::nullopt; + } + + auto lineno = source_view_->lineno_for_offset(start_); + auto col_offset = (int)start_ - (int)source_view_->offset_for_line(lineno); + // TODO: std::optional<>::value returns an rvalue ref so can't use it here?? + return std::make_tuple( + source_view_->filename().value_or(""), + source_view_->lineno_to_source_lineno(lineno), + (size_t)col_offset); + } + + bool operator==(const SourceRange& rhs) const { + return start() == rhs.start() && end() == rhs.end() && + source() == rhs.source(); + } + + bool operator!=(const SourceRange& rhs) const { + return !(*this == rhs); + } + + std::optional findSourceRangeThatGenerated() const { + if (!source_view_) { + return std::nullopt; + } + return source_view_->findSourceRangeThatGenerated(*this); + } + + protected: + std::shared_ptr source_view_; + + private: + size_t start_; + size_t end_; + StringCordView::Iterator start_iter_; +}; + +// OwnedSourceRange is just like a SourceRange except that it owns a `Source` +// instead of `Source`. Thus OwnedSourceRange owns a copy of source text. +struct OwnedSourceRange : public SourceRange { + explicit OwnedSourceRange(const SourceRange& source_range) + : SourceRange(source_range) { + const auto& source = source_range.source(); + if (source) { + source_view_ = std::make_shared( + source->text_str().str(), + source->filename(), + source->starting_line_no()); + } + } +}; + +struct TORCH_API SourceRangeHasher { + public: + size_t operator()(const torch::jit::SourceRange& key) const; +}; + +struct StackEntry { + std::string filename; + SourceRange range; +}; + +TORCH_API void format_stack_trace( + std::ostream& out, + const std::vector& entries); + +inline std::ostream& operator<<(std::ostream& out, const SourceRange& range) { + range.highlight(out); + return out; +} + +// A pair of (byte offset, SourceRange) describing a specific segment +// of the output stream +struct TaggedRange { + TaggedRange(size_t bytes, SourceRange range) + : bytes(bytes), range(std::move(range)) {} + size_t bytes; + SourceRange range; +}; +using SourceRangeRecords = std::vector; +using SourceRangeTagMap = + std::unordered_map; + +} // namespace torch::jit + +namespace std { +template <> +struct iterator_traits { + using value_type = char; + using difference_type = ptrdiff_t; + using pointer = char*; + using reference = char&; + using iterator_category = std::forward_iterator_tag; +}; +} // namespace std diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/source_ref.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/source_ref.h new file mode 100644 index 0000000000000000000000000000000000000000..c9ea38fa777503fe5a5b4bd8af382e799c651659 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/source_ref.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch::jit { + +/** + * SourceRef does two things: + * 1. Owns a Source object. + * 2. Serves as lookup key to the owned Source in associative containers, for + * runtime data aggregation. + * We don't want to use std::shared_ptr directly because we want to + * support heteogeneous lookup, and also shared_ptr is an implementation detail + * which should be encapsulated. + */ +class TORCH_API SourceRef : public CustomClassHolder { + public: + explicit SourceRef(std::shared_ptr source_view) + : source_view_(std::move(source_view)) {} + bool operator==(const SourceRef& other) const { + return source_view_ == other.source_view_; + } + bool operator<(const Source& other) const { + return source_view_.get() < &other; + } + friend bool operator<(const Source& other, const SourceRef& self) { + return &other < self.source_view_.get(); + } + bool operator<(const SourceRef& other) const { + return *this < *other.source_view_; + } + const Source* operator->() const { + return source_view_.get(); + } + + private: + std::shared_ptr source_view_; +}; + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/strtod.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/strtod.h new file mode 100644 index 0000000000000000000000000000000000000000..eb704a3e689ef79cb5977f59fe28a4b68b6eec72 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/strtod.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace torch::jit { + +TORCH_API double strtod_c(const char* nptr, char** endptr); +TORCH_API float strtof_c(const char* nptr, char** endptr); + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/sugared_value.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/sugared_value.h new file mode 100644 index 0000000000000000000000000000000000000000..04ba980bb4e16947986a0eabf5f3fa24fe7869a7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/sugared_value.h @@ -0,0 +1,861 @@ +#pragma once +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace torch::jit { + +using SugaredValuePtr = std::shared_ptr; + +// The AST can contain nodes like `self`, `self.b` or `python_fn` that +// are not first-class values in the graph representation, but instead +// will be desugared based on how they are used in the AST. + +// SugaredValue is used to temporarily represent these values in a way +// that separates their behavior from the AST -> IR converter itself. +// This allows us to keep dependencies on python minimal. + +struct TORCH_API SugaredValue + : public std::enable_shared_from_this { + // what is this node? for error reporting (e.g. Module, python function) + virtual std::string kind() const = 0; + + // what can we do with this thing? + // use it as a value e.g. `this + 4` + virtual Value* asValue(const SourceRange& loc, GraphFunction& m) { + throw(ErrorReport(loc) << kind() << " cannot be used as a value"); + } + + // select an attribute on it, e.g. `this.field` + virtual std::shared_ptr attr( + const SourceRange& loc, + GraphFunction& m, + const std::string& field) { + throw(ErrorReport(loc) << "attribute lookup is not defined on " << kind()); + } + + virtual bool hasAttr( + const SourceRange& loc, + GraphFunction& m, + const std::string& field) { + throw(ErrorReport(loc) << "attribute lookup is not defined on " << kind()); + } + + // assign an attribute on it, e.g. `this.field = newValue` + virtual void setAttr( + const SourceRange& loc, + GraphFunction& m, + const std::string& field, + Value* newValue) { + throw( + ErrorReport(loc) << "attribute assignment is not defined on " + << kind()); + } + + // use it as a vector of values, e.g. a tuple of values as return value from + // a method invocation + virtual std::vector> asTuple( + const SourceRange& loc, + GraphFunction& m, + const std::optional& size_hint = {}) { + throw(ErrorReport(loc) << kind() << " cannot be used as a tuple"); + } + + // TODO @wconstab refactor to use ModuleValue::asTuple instead of new API + virtual SugaredValuePtr asTupleValue( + const SourceRange& loc, + GraphFunction& m) { + throw(ErrorReport(loc) << kind() << " cannot be used as a tuplevalue"); + } + + virtual std::vector> asType( + const SourceRange& loc, + Method& m) { + throw(ErrorReport(loc) << kind() << " cannot be used as a type"); + } + + // call it like a function, e.g. `outputs = this(inputs)` + virtual std::shared_ptr call( + const SourceRange& loc, + GraphFunction& m, + // note: names for args will be 'argument 0', 'argument 1', etc.. + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) { + // n_binders is always set to the number of variables an expression is + // syntactically bound to: + // a = foo() # 1 binder (note in this case the single binder might be a + // tuple) a, * b = foo() # 1 binder a, b = foo() # 2 binders foo() # 0 + // binders + // + // In subexpressions, like bar() in foo(bar()), n_binders is always set to + // 1. n_binders is used as a hint to subexpressions to determine how many + // values they should return when that number is ambiguous statically. In + // particular it is currently used to decide how many tensors a call to a + // python function will return. It is only a hint, functions do not have to + // check that n_binders match the number of things they are returning, the + // assignment logic will do that anyway. + + throw(ErrorReport(loc) << "cannot call a " << kind()); + } + + // This function is called when to convert a SugaredValue to its iterator. + // For example, when iterating through a Dict we iterate over its keys + virtual std::shared_ptr iter( + const SourceRange& loc, + GraphFunction& m) { + throw(ErrorReport(loc) << kind() << " cannot be used as an iterable"); + } + + // If we are iterating over a Sugared Value and it returns a value from this + // function, then we emit an unrolled loop over the variable. This allows us + // to support containers of Heterogenous types, like Module Containers & + // Tuples + virtual std::optional staticLen() { + return std::nullopt; + } + + // When iterating over this SugaredValue, should we emit the for loop as an + // unrolled loop. + bool shouldEmitUnrolled() { + return staticLen() != std::nullopt; + } + + // return length of this thing, if not then it can't be iterated. + // If it does not have a statically-determinable length, then it cannot + // be iterated over with a modulelist. If it does it must return a constant + // Value * + virtual Value* len(const SourceRange& loc, GraphFunction& m) { + throw( + ErrorReport(loc) << "'" << kind() << "'" + << " object is not iterable"); + } + + // expression for ith elemement for iterable value + virtual std::shared_ptr getitem( + const SourceRange& loc, + GraphFunction& m, + Value* idx, + TypePtr type_hint = nullptr) { + throw( + ErrorReport(loc) << "'" << kind() << "'" + << " object is not subscriptable"); + } + + virtual ~SugaredValue() = default; +}; + +// most things in the environment are just simple value types +// and not special python syntax sugar types +struct TORCH_API SimpleValue : public SugaredValue { + SimpleValue(Value* value) : value_(value) {} + std::string kind() const override { + std::stringstream ss; + // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) + ss << "value of type '" << value_->type()->annotation_str() << "'"; + return ss.str(); + } + Value* asValue(const SourceRange& range, GraphFunction& m) override { + return value_; + } + std::vector> asTuple( + const SourceRange& loc, + GraphFunction& m, + const std::optional& size_hint = {}) override; + std::shared_ptr attr( + const SourceRange& loc, + GraphFunction& m, + const std::string& field) override; + + bool hasAttr( + const SourceRange& loc, + GraphFunction& m, + const std::string& field) override; + + void setAttr( + const SourceRange& loc, + GraphFunction& m, + const std::string& field, + Value* newValue) override; + + std::shared_ptr call( + const SourceRange& loc, + GraphFunction& m, + // note: names for args will be 'argument 0', 'argument 1', etc.. + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) override; + + std::shared_ptr iter(const SourceRange& loc, GraphFunction& m) + override; + + Value* getValue() const { + return value_; + } + + Value* len(const SourceRange& loc, GraphFunction& m) override; + SugaredValuePtr getitem( + const SourceRange& loc, + GraphFunction& m, + Value* idx, + TypePtr type_hint = nullptr) override; + + private: + Value* value_; +}; + +struct TORCH_API BuiltinFunction : public SugaredValue { + BuiltinFunction(Symbol symbol, std::optional self) + : symbol(symbol), self(std::move(self)) {} + + // The symbol of the function (e.g. `aten::relu`). + Symbol symbol; + + // if this is method, then this is the self argument. + std::optional self; + std::string kind() const override { + return "builtin"; + } + std::shared_ptr call( + const SourceRange& loc, + GraphFunction& m, + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) override; + + // try to create this builtin but if it doesn't exist or the self argument + // cannot possibly match, then return nullptr. Use in situations where it is + // not clear if it is a valid builtin + static std::shared_ptr tryCreate( + Symbol symbol, + std::optional self); +}; + +struct TORCH_API SugaredTupleValue : public SugaredValue { + explicit SugaredTupleValue(std::vector> tup) + : tup_(std::move(tup)) {} + + std::vector> asTuple( + const SourceRange& loc, + GraphFunction& m, + const std::optional& size_hint = {}) override { + return tup_; + } + + Value* asValue(const SourceRange& loc, GraphFunction& m) override { + std::vector vec; + vec.reserve(tup_.size()); + for (const auto& sv : tup_) { + vec.push_back(sv->asValue(loc, m)); + } + Graph& g = *m.graph(); + return g.insertNode(g.createTuple(vec))->output(); + } + + std::string kind() const override { + return "Tuple"; + } + + SugaredValuePtr getitem( + const SourceRange& loc, + GraphFunction& m, + Value* idx, + TypePtr type_hint = nullptr) override { + if (!(idx->type()->cast() && toIValue(idx))) { + throw( + ErrorReport(loc) + << "Expected integer literal for index but got a variable or non-integer. " + << "ModuleList/Sequential indexing is only supported with integer literals. " + << "For example, 'i = 4; self.layers[i](x)' will fail because i is not a literal. " + << "Enumeration is supported, e.g. 'for index, v in enumerate(self): out = v(inp)'"); + } + auto index = toIValue(idx)->toInt(); + int64_t adj_index = + (index < 0) ? index + static_cast(tup_.size()) : index; + if (!(adj_index >= 0 && adj_index < static_cast(tup_.size()))) { + throw( + ErrorReport(loc) << "Index " << index << " out of range of length " + << tup_.size()); + } + return tup_.at(adj_index); + } + + // This function is called when a SugaredValue is used to convert a + // SugaredValue to its iterator. For example, when iterating through a Dict we + // iterate over its keys + std::shared_ptr iter(const SourceRange& loc, GraphFunction& m) + override { + return shared_from_this(); + } + + // Because this is used to contain SugaredValues of Heterogenous types, + // we define staticLen() so that when this is iterated over it is emitted + // as an unrolled loop. + std::optional staticLen() override { + return static_cast(tup_.size()); + } + + std::vector> tup_; +}; + +struct TORCH_API BuiltinModule : public SugaredValue { + BuiltinModule(std::string name, std::optional version = std::nullopt) + : name(std::move(name)), version(version) {} + + std::string kind() const override { + return "builtin module"; + } + std::shared_ptr attr( + const SourceRange& loc, + GraphFunction& m, + const std::string& field) override { + if (field == "autograd") { + // When refering torch.autograd, it is also considered to be a + // BuiltinModule and we will dispatch to the aten operators for the + // methods under its module. + return std::make_shared("aten", version); + } + + auto sym = Symbol::fromQualString(name + "::" + field); + return std::make_shared(sym, std::nullopt); + } + + private: + std::string name; + // when we add operator versioning, emit this op as it exising at 'version' + // if not set, use the latest version + std::optional version; +}; + +// Represents a class, analagous to `int` or `dict`. Instances of classes, +// like `1` or `{"foo": 5}`, are represented as SimpleValues +struct TORCH_API ClassValue : public SugaredValue { + explicit ClassValue(ClassTypePtr type) : type_(std::move(type)) {} + + // Call the type's constructor, as in: + // n = Foo(constructor_arg) + std::shared_ptr call( + const SourceRange& loc, + GraphFunction& m, + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) override; + + std::shared_ptr attr( + const SourceRange& loc, + GraphFunction& m, + const std::string& field) override; + + std::string kind() const override { + return type_->str(); + } + + ClassTypePtr type_; +}; + +struct TORCH_API NamedTupleConstructor : public SugaredValue { + explicit NamedTupleConstructor(TupleTypePtr type) : type_(std::move(type)) {} + + std::shared_ptr call( + const SourceRange& loc, + GraphFunction& m, + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) override; + + std::string kind() const override { + return type_->str(); + } + + TupleTypePtr type_; +}; + +struct FunctionValue : public SugaredValue { + FunctionValue(Function* callee) : callees_({callee}) {} + FunctionValue(const StrongFunctionPtr& p) + : callees_({p.function_}), cu_(p.cu_) {} + FunctionValue(const std::vector& callees) { + for (const StrongFunctionPtr& callee : callees) { + cu_ = cu_ ? cu_ : callee.cu_; + TORCH_INTERNAL_ASSERT(callee.cu_ == cu_); + callees_.push_back(callee.function_); + } + } + + std::string kind() const override { + return "function"; + } + + std::shared_ptr call( + const SourceRange& loc, + GraphFunction& f, + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) override { + std::vector schemas; + for (Function* callee : callees_) { + try { + callee->ensure_defined(); + } catch (const RecursiveMethodCallError&) { + throw( + ErrorReport(loc) + << " function '" << callee->name() << "' is called recursively. " + << "Recursive calls are not supported"); + } + schemas.push_back(&callee->getSchema()); + } + auto match = matchSchemas(schemas, loc, *f.graph(), args, kwargs); + Value* output = + f.graph()->insertFunctionCall(callees_[match.first], match.second); + output->node()->setSourceRange(loc); + return std::make_shared(output); + } + + const std::vector& callees() { + return callees_; + } + + private: + std::vector callees_; + // TODO holding this thing is creepy + std::shared_ptr cu_; +}; + +struct TORCH_API ClosureValue : public SugaredValue { + ClosureValue(Value* value) : value_(value) { + TORCH_INTERNAL_ASSERT(value_->node()->kind() == prim::Closure); + } + std::string kind() const override { + return "closure"; + } + Value* asValue(const SourceRange& range, GraphFunction& m) override { + return value_; + } + Value* value_; +}; + +// defines how a method obtained from a module/class/interface behaves in script +struct MethodValue : public SugaredValue { + MethodValue(Value* self, std::vector method_names) + : self_(self), method_names_(std::move(method_names)) {} + MethodValue(Value* self, std::string method_name) + : MethodValue(self, std::vector({std::move(method_name)})) {} + + std::string kind() const override { + return "method"; + } + + std::shared_ptr call( + const SourceRange& loc, + GraphFunction& f, + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) override { + std::vector argsWithSelf = {self_}; + argsWithSelf.insert(argsWithSelf.end(), args.begin(), args.end()); + std::vector schemas; + for (const std::string& method_name : method_names_) { + if (auto class_type = self_->type()->cast()) { + Function& method = class_type->getMethod(method_name); + try { + method.ensure_defined(); + } catch (const RecursiveMethodCallError&) { + throw( + ErrorReport(loc) + << " method '" << method.name() << "' is called recursively. " + << "Recursive calls are not supported"); + } + schemas.push_back(&method.getSchema()); + } else if (auto interface_type = self_->type()->cast()) { + schemas.push_back(interface_type->getMethod(method_name)); + } else { + TORCH_INTERNAL_ASSERT( + false, "method constructed that is not a class or interface"); + } + } + auto match = matchSchemas(schemas, loc, *f.graph(), argsWithSelf, kwargs); + Value* output = + f.graph()->insertMethodCall(method_names_[match.first], match.second); + output->node()->setSourceRange(loc); + return std::make_shared(output); + } + + private: + Value* self_; + std::vector method_names_; +}; + +struct TORCH_API PrintValue : public SugaredValue { + std::string kind() const override { + return "print"; + } + std::shared_ptr call( + const SourceRange& loc, + GraphFunction& m, + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) override; +}; + +// expressions like int(x) +// these are the same as call prim::Int or equivalent except it +// is a noop when the input is a subtype of 'type' +struct TORCH_API CastValue : public BuiltinFunction { + CastValue(TypePtr type, c10::Symbol method) + : BuiltinFunction(method, std::nullopt), type_(std::move(type)) {} + std::shared_ptr call( + const SourceRange& loc, + GraphFunction& m, + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) override { + if (args.size() == 1 && kwargs.empty()) { + auto len_op = std::make_shared(aten::len, std::nullopt); + auto gt_op = std::make_shared(aten::gt, std::nullopt); + auto zero = m.graph()->insertConstant(0); + + auto v = args[0].value(*m.graph()); + if (v->type()->isSubtypeOf(*type_)) { + return std::make_shared(v); + } else if ( + *type_ == *BoolType::get() && + (v->type()->isSubtypeOf(*AnyListType::get()) || + v->type()->isSubtypeOf(*StringType::get()) || + v->type()->cast())) { + auto len = len_op->call(loc, m, {v}, {}, 1); + return gt_op->call(loc, m, {len->asValue(loc, m), zero}, {}, 1); + } + } + return BuiltinFunction::call(loc, m, args, kwargs, n_binders); + } + + private: + TypePtr type_; +}; + +struct TORCH_API TensorCastValue : public SugaredValue { + TensorCastValue(at::ScalarType type, NamedValue self) + : dtype_(type), self_(std::move(self)) {} + + std::string kind() const override { + return "Cast"; + } + + std::shared_ptr call( + const SourceRange& loc, + GraphFunction& m, + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) override { + TORCH_INTERNAL_ASSERT(args.empty() && kwargs.empty()); + Value* dtype_const = m.graph()->insertConstant(dtype_, loc); + std::vector kwargs_{ + self_, NamedValue(loc, "dtype", dtype_const)}; + Value* casted_val = m.graph()->insert( + /*opname=*/Symbol::fromQualString("aten::to"), + /*args=*/args, + /*kwargs=*/kwargs_, + /*range=*/loc); + return std::make_shared(casted_val); + } + + at::ScalarType dtype_; + NamedValue self_; +}; + +// builtins operators and functions that call a method if it exists +// on a class type, like 'len(x)' and 'x + y' +struct TORCH_API MagicMethod : public SugaredValue { + MagicMethod(std::string desugared_name, SugaredValuePtr base) + : base_value_(std::move(base)), + desugared_name_(std::move(desugared_name)) {} + + std::string kind() const override { + return desugared_name_; + } + + std::shared_ptr call( + const SourceRange& loc, + GraphFunction& m, + at::ArrayRef args, + at::ArrayRef kwargs, + size_t n_binders) override; + + private: + SugaredValuePtr base_value_; + std::string desugared_name_; +}; + +// things that look like function applications, but +// perform non-standard evaluation are represented +// with SpecialFormValues, e.g. +// isinstance(x, int) +// fork(fn) +// annotate(int, 3) +// The implementation of each value is handled by a case inside emitApplyExpr +struct TORCH_API SpecialFormValue : public SugaredValue { + SpecialFormValue(Symbol form) : form_(form) {} + std::string kind() const override { + return form_.toUnqualString(); + } + Symbol form() const { + return form_; + } + static std::shared_ptr create(Symbol form) { + return std::make_shared(form); + } + + private: + Symbol form_; +}; + +struct TORCH_API LegacyTensorConstructor : public SpecialFormValue { + LegacyTensorConstructor(Symbol form, at::ScalarType dtype, at::Device device) + : SpecialFormValue(form), device_(device), dtype_(dtype) {} + + static std::shared_ptr create( + Symbol form, + at::ScalarType dtype, + at::Device device) { + return std::make_shared(form, dtype, device); + } + at::ScalarType dtype() const { + return dtype_; + } + + private: + at::Device device_; + at::ScalarType dtype_; +}; + +// matched against for special handling of range expressions +struct TORCH_API RangeValue : SugaredValue { + RangeValue( + const SourceRange& loc, + GraphFunction& m, + std::vector input, + std::optional static_len = std::nullopt); + + std::string kind() const override { + return "range"; + } + Value* len(const SourceRange& loc, GraphFunction& m) override; + SugaredValuePtr getitem( + const SourceRange& loc, + GraphFunction& m, + Value* idx, + TypePtr type_hint = nullptr) override; + std::shared_ptr iter(const SourceRange& loc, GraphFunction& m) + override; + + // When Range is instantiated via enumerate(iterable_with_static_len), + // then it takes the static length of the iterable + std::optional staticLen() override { + return static_len_; + } + + private: + Value* start_{}; + Value* end_{}; + Value* step_{}; + // a flag to determine if it's a simple range() call with only end_ from + // arguments If true, we will not insert length calculation and index + // derivation nodes to simplify the graph and enable more possible + // optimizations + bool has_only_end_{}; + std::optional static_len_; +}; + +// Specialized Tree structure to matched against for special handling +// of builtin functions iterables expressions like zip(), enumerate(), etc. +// zip and enumerate can be modeled as a tree of SimpleValue/RangeValue: +// zip(x, y) -> (x, y) with tuple assignment to each loop target +// enumerate(x) -> (range(0, math.inf, 1), x) +// So a complicated expression like zip(a, enumerate(b), range(0, 100)) will be: +// (a, (range(0, math.inf, 1), b), range(0, 100)) +// We use those base iterables to fill in the loop information like +// max_trip_count and set the value table for loop targets +// Iterables can contain lists of SugaredValues like ModuleLists. If it +// does, then we emit it unrolled and require that all values it contains +// have a statically-determinable length. +struct TORCH_API IterableTree : SugaredValue { + IterableTree() = default; + IterableTree( + const SourceRange& range, + GraphFunction& m, + at::ArrayRef children) { + for (const auto& child : children) { + addChild(range, m, child); + } + } + std::string kind() const override { + return "iterabletree"; + } + + std::shared_ptr iter(const SourceRange& loc, GraphFunction& m) + override { + return shared_from_this(); + } + + void addChild( + const SourceRange& range, + GraphFunction& m, + const SugaredValuePtr& iter_value); + + std::vector get_children() { + return children_; + } + + // If this iterable contains a ModuleList or Tuple, then it will have a + // static length, and we will emit it as an unrolled for loop. + std::optional staticLen() override { + return unroll_length_; + } + + // given a IterableTree node, get all the base iterables/leaves under the + // IterableTree node. This enables + // us to get all the basic SugaredValues that contains valid loop information + // with len() and getitem() + std::vector get_base_iterables(); + + Value* len(const SourceRange& loc, GraphFunction& m) override; + SugaredValuePtr getitem( + const SourceRange& loc, + GraphFunction& m, + Value* idx, + TypePtr type_hint = nullptr) override; + + private: + std::optional unroll_length_ = std::nullopt; + std::vector children_; +}; + +static inline std::vector toValues( + Graph& g, + at::ArrayRef nvs) { + return fmap(nvs, [&](const NamedValue& v) { return v.value(g); }); +} + +struct SimpleSelf : public Self { + explicit SimpleSelf(ClassTypePtr classType) + : Self(), classType_(std::move(classType)) {} + std::shared_ptr makeSugared(Value* v) const override { + v->setType(classType_); + return std::make_shared(v); + } + ClassTypePtr getClassType() const override { + return classType_; + } + + private: + ClassTypePtr classType_; +}; + +// This is not a SimpleValue so it can not pass through the code paths that +// expect a SimpleValue as a sugared value. +struct TORCH_API ExceptionMessageValue : public SugaredValue { + explicit ExceptionMessageValue( + Value* value, + Value* qualified_class_name = nullptr) + : value_(value), qualified_class_name_(qualified_class_name) {} + + std::string kind() const override { + return "exception message"; + } + + Value* getValue() { + return value_; + } + + // qualified python class name + Value* getQualifiedClassName() { + return qualified_class_name_; + } + + private: + Value* value_; + Value* qualified_class_name_; +}; + +struct TORCH_API ExceptionValue : public SugaredValue { + explicit ExceptionValue(std::string message) : message_(std::move(message)) {} + + std::string kind() const override { + return "exception"; + } + + std::shared_ptr call( + const SourceRange& loc, + GraphFunction& m, + at::ArrayRef args, + at::ArrayRef /*attributes*/, + size_t /*n_binders*/) override { + auto exception_message = insertConstant(*m.graph(), message_ + ": ", loc); + for (auto& input : args) { + auto input_str = input.value(*m.graph()); + if (!input_str->type()->isSubtypeOf(*StringType::get())) { + input_str = + emitBuiltinCall(loc, *m.graph(), aten::str, {input_str}, {}); + } + exception_message = emitBuiltinCall( + loc, *m.graph(), aten::add, {exception_message, input_str}, {}); + } + return std::make_shared(exception_message); + } + + std::string message_; +}; + +struct TORCH_API SugaredEnumClass : public SugaredValue { + explicit SugaredEnumClass(EnumTypePtr enum_type) + : enum_type_(std::move(enum_type)) {} + + std::string kind() const override { + return "EnumClass"; + } + + SugaredValuePtr attr( + const SourceRange& loc, + GraphFunction& m, + const std::string& field) override; + + SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override; + + private: + EnumTypePtr enum_type_; +}; + +struct TORCH_API SliceValue : public SugaredValue { + explicit SliceValue(Value* start, Value* stop, Value* step) + : start_(start), stop_(stop), step_(step) {} + + std::string kind() const override { + return "Python slice value"; + } + + Value* start() { + return start_; + } + Value* stop() { + return stop_; + } + Value* step() { + return step_; + } + + private: + Value* start_; + Value* stop_; + Value* step_; +}; + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/tracer.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/tracer.h new file mode 100644 index 0000000000000000000000000000000000000000..dbfc6faa88c4038c7b94181afad55db8942cc8f4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/tracer.h @@ -0,0 +1,413 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace torch::jit { +struct Node; +struct Value; +struct Graph; +struct Module; + +namespace tracer { + +using ::c10::ivalue::Shared; + +using ::c10::IValue; +using ::c10::ivalue::Future; + +using ::c10::ArrayRef; +using ::c10::TupleType; +using ::c10::TupleTypePtr; +using ::c10::ivalue::ConstantString; + +using torch::autograd::Variable; +using variable_list = std::vector; + +TORCH_API std::atomic& getTracerStateWarnMode(); + +struct TORCH_API TracingState + : public std::enable_shared_from_this { + TracingState(); + ~TracingState(); + + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::shared_ptr graph; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool warn = getTracerStateWarnMode(); + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool strict = true; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool force_outplace = false; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::function lookup_var_name_fn = + [](const Variable& var) { return ""; }; + + void enterFrame() { + env_stack.emplace_back(); + } + + void leaveFrame() { + env_stack.pop_back(); + } + + void setValue(const IValue& v, Value* value); + void delValue(const IValue& var); + Value* getValue(const IValue& var); + Value* getOutput(const IValue& var, size_t i); + bool hasValue(const IValue& var) const; + + Node* createNode(c10::Symbol op_name, size_t num_outputs); + void insertNode(Node* node); + + private: + using WeakIValue = at::WeakIValue; + + struct WeakIValueHasher { + size_t operator()(const WeakIValue& t) const { + return t.hash(); + } + }; + + struct WeakIValueEq { + bool operator()(const WeakIValue& t1, const WeakIValue& t2) const { + return t1.isSameIdentity(t2); + } + }; + + using Frame = + std::unordered_map; + std::vector env_stack; +}; + +// This is meant to be used as a thread local place, where we can store extra +// info that gets lost when we call into ATen from Python bindings. One example +// for when this happens is when we get an IntArrayRef argument with e.g. sizes +// for view. When tracing, those might be tensors, which let us encode extra +// data dependencies, but once they get to the ATen call where we actually have +// the tracing logic, they get converted into a raw IntArrayRef, and we loose +// all information. To prevent this, we temporarily stash it in here. +struct ArgumentStash { + struct IntArrayRefTrace : std::vector { + IntArrayRefTrace(size_t size) : std::vector(size, nullptr) {} + }; + + static bool empty() { + return stash.intlists.empty(); + } + + TORCH_API static void stashIntArrayRefElem( + const std::string& arg_name, + size_t size, + size_t idx, + const Variable& var); + + static bool hasIntArrayRef(const std::string& arg_name) { + return stash.intlists.count(arg_name) > 0; + } + + static IntArrayRefTrace popIntArrayRef(const std::string& arg_name) { + auto info = std::move(stash.intlists.at(arg_name)); + stash.intlists.erase(arg_name); + return info; + } + + // Value stashing: Use these methods to stash arguments which correspond + // to regular Value*'s in the graph. i.e. they don't require special + // handling like in the case of IntArrayRefs + TORCH_API static void stashValue( + const std::string& arg_name, + size_t idx, + const Variable& var, + const c10::TypePtr& type = nullptr); + + static bool hasValue(const std::string& arg_name) { + return stash.values.count(arg_name) > 0; + } + + static Value* popValue(const std::string& arg_name) { + auto info = stash.values.at(arg_name); + stash.values.erase(arg_name); + return info; + } + + private: + static thread_local ArgumentStash stash; + std::unordered_map intlists; + std::unordered_map values; +}; + +// Retrieve or set the current tracing state. Returns a nullptr if tracing is +// disabled. +TORCH_API const std::shared_ptr& getTracingState(); +TORCH_API void setTracingState(std::shared_ptr state); + +inline bool isTracing() { + return static_cast(getTracingState()); +} + +using warn_fn_type = void (*)(const std::string& msg); +TORCH_API extern const char* WARN_PYTHON_DATAFLOW; +TORCH_API extern const char* WARN_CONSTRUCTOR; +TORCH_API extern const char* WARN_RESIZE; +TORCH_API extern const char* STRICT_TRACER_MSG; +TORCH_API void _do_warn(const char* _reason, const char* _kind); +inline void warn(const char* _reason, const char* _kind = nullptr) { + if (const auto& state = getTracingState()) { + if (!state->warn) + return; + _do_warn(_reason, _kind); + } +} +TORCH_API void setWarn(warn_fn_type fn); + +struct TORCH_API NoWarn { + NoWarn() : state(getTracingState()) { + if (state) { + prev = state->warn; + state->warn = false; + } + } + ~NoWarn() { + if (state) { + state->warn = prev; + } + } + std::shared_ptr state; + bool prev{false}; +}; + +struct WithNestedTracingFrame { + WithNestedTracingFrame() { + getTracingState()->enterFrame(); + } + + ~WithNestedTracingFrame() { + getTracingState()->leaveFrame(); + } +}; +TORCH_API void recordSourceLocation(Node* n); +TORCH_API void setRecordSourceLocation(void (*v)(Node*)); + +TORCH_API std::vector pythonCallstack(); +TORCH_API void setPythonCallstack(std::vector (*v)()); + +// Having finished adding a new 'node' to the graph IR 'setValueTrace' +// associates this node with an output variable, so that further operations +// involving this variable know which node in the IR to reference. +TORCH_API void setValueTrace(const IValue& v, Value* value); + +TORCH_API void delValueTrace(const IValue& var); + +TORCH_API std::function pauseTracing(); + +TORCH_API Value* getValueTrace(const IValue& var); + +TORCH_API std::pair, Stack> trace( + Stack inputs, + const std::function& traced_fn, + std::function var_name_lookup_fn, + bool strict = true, + bool force_outplace = false, + Module* self = nullptr, + const std::vector& argument_names = {}); + +TORCH_API void abandon(); + +// NB: those serve both as an intermediate steps in addInputs below, +// as well as the overloads that terminate template recursion +TORCH_API void addInputs(Node* n, const char* name, int64_t value); +TORCH_API void addInputs(Node* n, const char* name, const c10::SymInt& value); +TORCH_API void addInputs( + Node* n, + const char* name, + std::optional value); +TORCH_API void addInputs(Node* n, const char* name, bool value); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::optional& value); +TORCH_API void addInputs(Node* n, const char* name, double value); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::optional& value); +TORCH_API void addInputs(Node* n, const char* name, const at::Scalar& value); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::optional& value); +TORCH_API void addInputs(Node* n, const char* name, const at::Tensor& value); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::optional& value); +TORCH_API void addInputs(Node* n, const char* name, ArrayRef value); +TORCH_API void addInputs(Node* n, const char* name, c10::SymIntArrayRef value); +TORCH_API void addInputs( + Node* n, + const char* name, + std::optional value); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::optional>& value); +TORCH_API void addInputs( + Node* n, + const char* name, + const at::OptionalIntArrayRef& opt_value); +TORCH_API void addInputs( + Node* n, + const char* name, + const at::OptionalSymIntArrayRef& opt_value); +TORCH_API void addInputs( + Node* n, + const char* name, + ArrayRef value, + bool allow_undefined = false); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::vector& value, + bool allow_undefined = false); +TORCH_API void addInputs( + Node* n, + const char* name, + at::ITensorListRef value, + bool allow_undefined = false); +TORCH_API void addInputs( + Node* n, + const char* name, + const List>& value); +TORCH_API void addInputs( + Node* n, + const char* name, + ArrayRef> value, + const c10::ClassTypePtr& class_type); +TORCH_API void addInputs(Node* n, const char* name, ArrayRef value); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::optional>& value); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::string_view value); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::optional& value); +TORCH_API void addInputs(Node* n, const char* name, at::Device value); +TORCH_API void addInputs(Node* n, const char* name, c10::Stream stream); +TORCH_API void addInputs(Node* n, const char* name, at::Layout value); +TORCH_API void addInputs(Node* n, const char* name, at::ScalarType value); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::optional& value); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::optional& value); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::optional& value); +TORCH_API void addInputs(Node* n, const char* name, at::MemoryFormat value); +TORCH_API void addInputs( + Node* n, + const char* name, + std::optional value); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::optional& value); +TORCH_API void addInputs( + Node* n, + const char* name, + const std::optional& value); + +inline void addInputs( + Node* n, + const char* name, + const std::vector& value) { + TORCH_CHECK(false, "Tracing a list of bool type is currently not supported!"); +} + +template +void addInputs(Node* n, const char* name, ArrayRef value) { + TORCH_CHECK( + false, "Tracing a list of arbitrary type is currently not supported!"); +} +template +void addInputs( + Node* n, + const char* name, + const std::unordered_map& value) { + TORCH_CHECK( + false, "Tracing a dict of arbitrary types is currently not supported!"); +} + +template +void addInputs(Node* n, const char* name, std::array value) { + throw std::runtime_error( + "Found an unsupported argument type in the JIT tracer. File a bug report."); +} + +TORCH_API void addInputs( + Node* n, + const char* name, + const c10::intrusive_ptr& obj); + +TORCH_API void ensureUniqueIfOutOfPlaced( + const char* name, + const at::Tensor& tensor); +TORCH_API void ensureUniqueIfOutOfPlaced( + const char* name, + const std::optional& tensor); + +template < + typename T, + typename = std::enable_if_t< + (!std::is_convertible_v, at::TensorList> && + !std::is_convertible_v, c10::List> && + !std::is_convertible_v, at::Tensor> && + !std::is_convertible_v< + std::decay_t, + c10::intrusive_ptr>)>> +void addOutput(Node* node, T&&) { + TORCH_CHECK( + false, + "Found an unsupported argument type ", + c10::demangle_type(), + " in the JIT tracer. File a bug report."); +} +TORCH_API void addOutput(Node* node, const at::Tensor& tensor); +TORCH_API void setOutput(Value* value, const at::Tensor& output); +TORCH_API void addOutput(Node* node, const std::vector& list); +TORCH_API void addOutput(Node* node, const c10::List& list); +TORCH_API void addOutput( + Node* node, + const c10::intrusive_ptr& output); + +TORCH_API autograd::Variable getSizeOf( + const autograd::Variable& var, + int64_t dim); + +TORCH_API autograd::Variable getNumelOf(const autograd::Variable& var); + +} // namespace tracer +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/tree.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/tree.h new file mode 100644 index 0000000000000000000000000000000000000000..84e5e7755fef798c9dc107c0d524c73dd110693e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/tree.h @@ -0,0 +1,218 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::jit { + +// Trees are used to represent all forms of TC IR, pre- and post-typechecking. +// Rather than have a full class hierarchy for all TC statements, trees are a +// slight variation of Lisp s-expressions. For instance, the expression a*b+1 +// is represented as: +// (+ (* (ident a) (ident b)) (const 1)) +// Atoms like 'a', 'b', and '1' are represented by subclasses of Tree which +// define stringValue(). Everything else is a Compound object, which has a +// 'kind' that is a token from lexer.h's TokenKind enum. Single-character +// operators like '+' are represented using the character itself (so, add.kind() +// would be '+'). Each Compound object also contains a list of subtrees and is +// associated with a SourceRange for error reporting. +// Memory management of trees is done using intrusive_ptr. + +struct Tree; +using TreeRef = c10::intrusive_ptr; +using TreeList = at::SmallVector; + +struct Tree : c10::intrusive_ptr_target { + Tree(int kind_) : kind_(kind_) {} + int kind() const { + return kind_; + } + virtual bool isAtom() const { + return true; + } + virtual const SourceRange& range() const { + throw std::runtime_error("is an Atom"); + } + virtual const std::string& stringValue() const { + throw std::runtime_error("stringValue can only be called on TK_STRING"); + } + virtual const TreeList& trees() const { + static const TreeList empty_trees = {}; + return empty_trees; + } + const TreeRef& tree(size_t i) const { + return trees().at(i); + } + virtual TreeRef map(const std::function& fn) { + (void)fn; + c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer + // from a raw `this` pointer + // so we need to bump the refcount + // to account for this ownership + return TreeRef::reclaim(this); + } + template + void match(int k, Args&... args) const { + matchD(k, "unknown", 0, args...); + } + template + void matchD(int k, const char* filename, int lineno, Args&... args) const { + std::initializer_list vars = {args...}; + matchNumSubtreesD(k, filename, lineno, vars.size(), true); + size_t i = 0; + for (TreeRef* v : vars) { + *v = trees()[i++]; + } + } + void matchNumSubtrees(int k, size_t expected_subtrees) { + return matchNumSubtreesD(k, "unknown", 0, expected_subtrees, false); + } + void matchNumSubtreesD( + int k, + const char* filename, + int lineno, + size_t expected_subtrees, + bool allow_more) const { + if (kind() != k) { + std::stringstream ss; + ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k) + << "' but found '" << kindToString(kind()) << "'\n"; + range().highlight(ss); + throw std::runtime_error(ss.str()); + } + if (trees().size() < expected_subtrees || + (!allow_more && trees().size() != expected_subtrees)) { + std::stringstream ss; + ss << filename << ":" << lineno << ": expected at least " + << expected_subtrees << " subtrees, but found only " << trees().size() + << "\n"; + range().highlight(ss); + throw std::runtime_error(ss.str()); + } + } + ~Tree() override = default; + + private: + int kind_; +}; + +struct String : public Tree { + String(std::string value) : Tree(TK_STRING), value_(std::move(value)) {} + const std::string& stringValue() const override { + return value_; + } + template + static TreeRef create(Args&&... args) { + return c10::make_intrusive(std::forward(args)...); + } + + private: + std::string value_; +}; + +static SourceRange mergeRanges(SourceRange c, const TreeList& others) { + for (const auto& t : others) { + if (t->isAtom()) + continue; + size_t s = std::min(c.start(), t->range().start()); + size_t e = std::max(c.end(), t->range().end()); + c = SourceRange(c.source(), s, e); + } + return c; +} + +struct Compound : public Tree { + Compound(int kind, SourceRange range) + : Tree(kind), range_(std::move(range)) {} + Compound(int kind, const SourceRange& range_, TreeList&& trees_) + : Tree(kind), + range_(mergeRanges(range_, trees_)), + trees_(std::move(trees_)) {} + const TreeList& trees() const override { + return trees_; + } + static TreeRef create( + int kind, + const SourceRange& range_, + TreeList&& trees_) { + return c10::make_intrusive(kind, range_, std::move(trees_)); + } + bool isAtom() const override { + return false; + } + TreeRef map(const std::function& fn) override { + TreeList ret; + for (auto& t : trees()) { + ret.push_back(fn(t)); + } + return Compound::create(kind(), range(), std::move(ret)); + } + + const SourceRange& range() const override { + return range_; + } + + private: + SourceRange range_; + TreeList trees_; +}; + +// tree pretty printer +struct pretty_tree { + pretty_tree(const TreeRef& tree, size_t col = 40) : tree(tree), col(col) {} + const TreeRef& tree; + size_t col; + std::unordered_map flat_strings; + const std::string& get_flat(const TreeRef& t) { + auto it = flat_strings.find(t); + if (it != flat_strings.end()) + return it->second; + + std::stringstream out; + switch (t->kind()) { + case TK_STRING: + out << t->stringValue(); + break; + default: + out << "(" << kindToString(t->kind()); + for (const auto& e : t->trees()) { + out << " " << get_flat(e); + } + out << ")"; + break; + } + auto it_ = flat_strings.emplace(t, out.str()); + return it_.first->second; + } + void print(std::ostream& out, const TreeRef& t, int indent) { + const std::string& s = get_flat(t); + if (indent + s.size() < col || t->isAtom()) { + out << s; + return; + } + std::string k = kindToString(t->kind()); + out << "(" << k; + for (const auto& e : t->trees()) { + out << "\n" << std::string(indent + 2, ' '); + print(out, e, indent + 2); + } + out << ")"; + } +}; + +static inline std::ostream& operator<<(std::ostream& out, pretty_tree t_) { + t_.print(out, t_.tree, 0); + return out << '\n'; +} + +static inline std::ostream& operator<<(std::ostream& out, const TreeRef& t) { + return out << pretty_tree(t); +} + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/tree_views.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/tree_views.h new file mode 100644 index 0000000000000000000000000000000000000000..f0850e86886dca5818f2991262ae5df1b3706428 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/tree_views.h @@ -0,0 +1,1280 @@ +#pragma once +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace torch::jit { + +// clang-format off +// TreeView provides a statically-typed way to traverse the tree, which should +// be formed according to the grammar below. +// +// A few notes on types and their aliases: +// - List is really a Tree with kind TK_LIST and elements as subtrees +// - Maybe is really a Tree with kind TK_OPTION that has 0 or 1 subtree of type T +// - Builtin types are: Ident (TK_IDENT), String (TK_STRING) +// +// Param = Param(Maybe type, Ident name) TK_PARAM +// +// Decl = Decl(List params, Maybe return_type) TK_DECL +// Def = Def(Ident name, Decl decl, List body) TK_DEF +// ClassDef = ClassDef(Ident name, TK_CLASS_DEF +// Maybe superclass, +// List body) +// +// Stmt = If(Expr cond, List true_body, List false_body) TK_IF +// | For(List targets, List iters, List body) TK_FOR +// | While(Expr cond, List body) TK_WHILE +// | Global(List idents) TK_GLOBAL +// -- NB: the only type of Expr's allowed on lhs are Var +// Or a tuple containing Var with an optional terminating Starred +// | Assign(Expr lhs, Maybe rhs, Maybe type) TK_ASSIGN +// | AugAssign(Expr lhs, AugAssignKind aug_op, Expr rhs) TK_AUG_ASSIGN +// | Return(List values) TK_RETURN +// | ExprStmt(List expr) TK_EXPR_STMT +// | Raise(Expr expr) TK_RAISE +// | Def TK_DEF +// | With(List targets, List body) TK_WITH +// +// Expr = TernaryIf(Expr cond, Expr true_expr, Expr false_expr) TK_IF_EXPR +// | BinOp(Expr lhs, Expr rhs) +// | And TK_AND +// | Or TK_OR +// | Lt '<' +// | Gt '>' +// | Eq TK_EQ +// | Le TK_LE +// | Ge TK_GE +// | Ne TK_NE +// | Is TK_IS +// | IsNot TK_ISNOT +// | Add '+' +// | Sub '-' +// | Mul '*' +// | Div '/' +// | Mod '%' +// | MatMult '@' +// | Pow TK_POW +// | UnaryOp(Expr expr) +// | Not TK_NOT +// | USub '-' +// | Const(String value) TK_CONST +// -- NB: x.name(y) is desugared into name(x, y) +// | Apply(Ident name, List args, List kwargs) TK_APPLY +// | Select(Expr value, Ident selector) '.' +// | Subscript(Expr value, List subscript_exprs) TK_SUBSCRIPT +// | SliceExpr(Maybe start, Maybe end) TK_SLICE_EXPR +// | Var(Ident name) TK_VAR +// | ListLiteral(List inputs) TK_LIST_LITERAL +// | TupleLiteral(List inputs) TK_TUPLE_LITERAL +// | Starred(Expr expr) TK_STARRED +// | WithItem(Expr target, Maybe var) TK_WITH_ITEM +// -- NB: only allowed expressions are Const or List(Const) +// (List as a value, not type constructor) +// Attribute = Attribute(Ident name, Expr value) TK_ATTRIBUTE +// +// AugAssignKind = +// | Add() TK_PLUS_EQ +// | Sub() TK_MINUS_EQ +// | Mul() TK_TIMES_EQ +// | Div() TK_DIV_EQ +// | Mod() TK_MOD_EQ +// + +// Each subclass of TreeView should provide: +// 1. Constructor that takes a TreeRef, and checks that it's of the right type. +// 2. Accessors that get underlying information out of the object. If they +// return subtrees, they should wrap them in appropriate views too. +// 3. Static method 'create' that creates the underlying TreeRef object +// for every TreeRef kind that has a TreeView, the parser always uses +// (e.g.) Ident::create rather than Compound::Create, this means that +// changes to the structure of Ident are always made right here rather +// than both in the parser and in this code. +// XXX: these structs should have no fields to prevent slicing when passing by value +// clang-format on +struct TreeView { + explicit TreeView(TreeRef tree) : tree_(std::move(tree)) {} + TreeRef tree() const { + return tree_; + } + const SourceRange& range() const { + return tree_->range(); + } + operator TreeRef() const { + return tree_; + } + const TreeRef& get() const { + return tree_; + } + int kind() const { + return tree_->kind(); + } + void dump() const { + std::cout << tree_; + } + + protected: + const TreeRef& subtree(size_t i) const { + return tree_->trees().at(i); + } + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + TreeRef tree_; +}; + +template +struct ListIterator { + ListIterator(TreeList::const_iterator it) : it(it) {} + bool operator!=(const ListIterator& rhs) const { + return it != rhs.it; + } + bool operator==(const ListIterator& rhs) const { + return it == rhs.it; + } + T operator*() const { + return T(*it); + } + ListIterator& operator+=(std::ptrdiff_t n) { + it += n; + return *this; + } + ListIterator& operator++() { + ++it; + return *this; + } + ListIterator& operator--() { + --it; + return *this; + } + + private: + TreeList::const_iterator it; +}; + +template +struct List : public TreeView { + using iterator = ListIterator; + using const_iterator = ListIterator; + + List(const TreeRef& tree) : TreeView(tree) { + tree->match(TK_LIST); + // Iterate over list to temporarily instantiate Ts that will check the type + for (const T& elem : *this) { + (void)elem; // silence unused warning + } + } + iterator begin() const { + return iterator(tree_->trees().begin()); + } + iterator end() const { + return iterator(tree_->trees().end()); + } + bool empty() const { + return tree_->trees().begin() == tree_->trees().end(); + } + T operator[](size_t i) const { + return T(subtree(i)); + } + TreeRef map(const std::function& fn) { + return tree_->map([&](TreeRef v) { return fn(T(v)); }); + } + static List create(const SourceRange& range, const std::vector& subtrees) { + TreeList type_erased_sub{subtrees.begin(), subtrees.end()}; + return List(Compound::create(TK_LIST, range, std::move(type_erased_sub))); + } + static List unsafeCreate(const SourceRange& range, TreeList&& subtrees) { + return List(Compound::create(TK_LIST, range, std::move(subtrees))); + } + size_t size() const { + return tree_->trees().size(); + } +}; + +template +struct Maybe : public TreeView { + explicit Maybe(const TreeRef& tree) : TreeView(tree) { + tree_->match(TK_OPTION); + if (tree_->trees().size() > 1) + throw(ErrorReport(tree) << "Maybe trees can have at most one subtree"); + } + /* implicit */ Maybe(const T& tree) : TreeView(tree) {} + bool present() const { + return tree_->trees().size() > 0; + } + T get() const { + return T(tree_->trees().at(0)); + } + TreeRef map(const std::function& fn) { + return tree_->map([&](TreeRef v) { return fn(T(v)); }); + } + static Maybe create(const SourceRange& range) { + return Maybe(Compound::create(TK_OPTION, range, {})); + } + static Maybe create(const SourceRange& range, const T& value) { + return Maybe(Compound::create(TK_OPTION, range, {value})); + } +}; + +struct Ident : public TreeView { + explicit Ident(const TreeRef& tree) : TreeView(tree) { + tree_->match(TK_IDENT); + } + const std::string& name() const { + return subtree(0)->stringValue(); + } + static Ident create(const SourceRange& range, std::string name) { + return Ident( + Compound::create(TK_IDENT, range, {String::create(std::move(name))})); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Base types (production LHS) +//////////////////////////////////////////////////////////////////////////////// + +struct Stmt : public TreeView { + explicit Stmt(const TreeRef& tree) : TreeView(tree) { + switch (tree->kind()) { + case TK_IF: + case TK_FOR: + case TK_WHILE: + case TK_GLOBAL: + case TK_ASSIGN: + case TK_AUG_ASSIGN: + case TK_RETURN: + case TK_EXPR_STMT: + case TK_RAISE: + case TK_ASSERT: + case TK_PASS: + case TK_BREAK: + case TK_DELETE: + case TK_CONTINUE: + case TK_DEF: + case TK_WITH: + return; + default: + throw( + ErrorReport(tree) + << kindToString(tree->kind()) << " is not a valid Stmt"); + } + } +}; + +struct Expr : public TreeView { + explicit Expr(const TreeRef& tree) : TreeView(tree) { + switch (tree->kind()) { + case TK_IF_EXPR: + case TK_AND: + case TK_OR: + case '<': + case '>': + case TK_IS: + case TK_ISNOT: + case TK_EQ: + case TK_LE: + case TK_GE: + case TK_NE: + case '+': + case '-': + case TK_UNARY_MINUS: + case '~': + case '*': + case TK_STARRED: + case '/': + case '%': + case TK_NOT: + case TK_CONST: + case TK_STRINGLITERAL: + case TK_TRUE: + case TK_FALSE: + case TK_NONE: + case TK_NONE_TYPE: + case TK_CAST: + case TK_APPLY: + case '.': + case TK_SUBSCRIPT: + case TK_SLICE_EXPR: + case TK_VAR: + case TK_LIST_LITERAL: + case TK_TUPLE_LITERAL: + case TK_DICT_LITERAL: + case '@': + case TK_POW: + case TK_LSHIFT: + case TK_RSHIFT: + case TK_FLOOR_DIV: + case '&': + case '^': + case '|': + case TK_LIST_COMP: + case TK_DICT_COMP: + case TK_DOTS: + case TK_IN: + case TK_WITH_ITEM: + return; + default: + throw( + ErrorReport(tree) + << kindToString(tree->kind()) << " is not a valid Expr"); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Helper nodes (mostly for function arguments) +//////////////////////////////////////////////////////////////////////////////// + +struct Attribute : public TreeView { + explicit Attribute(const TreeRef& tree) : TreeView(tree) { + tree_->match(TK_ATTRIBUTE); + } + Ident name() const { + return Ident(subtree(0)); + } + Expr value() const { + return Expr(subtree(1)); + } + static Attribute create( + const SourceRange& range, + const Ident& name, + const TreeRef& value) { + return Attribute(Compound::create(TK_ATTRIBUTE, range, {name, value})); + } +}; + +struct Param : public TreeView { + explicit Param(const TreeRef& tree) : TreeView(tree) { + tree_->match(TK_PARAM); + } + static Param create( + const SourceRange& range, + const Ident& ident, + const Maybe& type, + const Maybe& def, + bool kwarg_only) { + TreeRef kwarg_only_tree = + Compound::create(kwarg_only ? TK_TRUE : TK_FALSE, range, {}); + return Param(Compound::create( + TK_PARAM, range, {ident, type, def, std::move(kwarg_only_tree)})); + } + Ident ident() const { + return Ident(subtree(0)); + } + Maybe type() const { + return Maybe(subtree(1)); + } + Maybe defaultValue() const { + return Maybe(subtree(2)); + } + bool kwarg_only() const { + return TK_TRUE == subtree(3)->kind(); + } + Param withType(const Maybe& typ) const { + return Param::create(range(), ident(), typ, defaultValue(), kwarg_only()); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Top level definitions +//////////////////////////////////////////////////////////////////////////////// + +struct Decl : public TreeView { + explicit Decl(const TreeRef& tree) : TreeView(tree) { + tree->match(TK_DECL); + } + List params() const { + return List(subtree(0)); + } + Maybe return_type() const { + return Maybe(subtree(1)); + } + static Decl create( + const SourceRange& range, + const List& params, + const Maybe& return_type) { + return Decl(Compound::create(TK_DECL, range, {params, return_type})); + } +}; + +struct Def : public TreeView { + explicit Def(const TreeRef& tree) : TreeView(tree) { + tree->match(TK_DEF); + } + Def withName(std::string new_name) const { + auto new_ident = Ident::create(name().range(), std::move(new_name)); + return create(range(), new_ident, decl(), statements()); + } + Def withDecl(const Decl& decl) const { + return create(range(), name(), decl, statements()); + } + Ident name() const { + return Ident(subtree(0)); + } + Decl decl() const { + return Decl(subtree(1)); + } + List statements() const { + return List(subtree(2)); + } + static Def create( + const SourceRange& range, + const Ident& name, + const Decl& decl, + const List& stmts) { + return Def(Compound::create(TK_DEF, range, {name, decl, stmts})); + } +}; + +// Property represents a named attribute combined with a getter and setter +// method to access and mutate that attribute. +struct Property : public TreeView { + explicit Property(const TreeRef& tree) : TreeView(tree) { + tree->match(TK_PROP); + } + Ident name() const { + return Ident(subtree(0)); + } + Def getter() const { + return Def(subtree(1)); + } + Maybe setter() const { + return Maybe(subtree(2)); + } + static Property create( + const SourceRange& range, + const Ident& name, + const Def& getter, + const Maybe& setter) { + return Property(Compound::create(TK_PROP, range, {name, getter, setter})); + } +}; + +struct Assign; + +struct ClassDef : public TreeView { + explicit ClassDef(const TreeRef& tree) : TreeView(tree) { + tree->match(TK_CLASS_DEF); + } + explicit ClassDef(TreeRef&& tree) : TreeView(std::move(tree)) { + tree_->match(TK_CLASS_DEF); + } + ClassDef withName(std::string new_name) const { + auto new_ident = Ident::create(name().range(), std::move(new_name)); + return create(range(), new_ident, superclass(), body()); + } + Ident name() const { + return Ident(subtree(0)); + } + Maybe superclass() const { + return Maybe(subtree(1)); + } + List body() const { + return List(subtree(2)); + } + Maybe> properties() const { + return Maybe>(subtree(3)); + } + Maybe> assigns() const { + return Maybe>(subtree(4)); + } + static ClassDef create( + const SourceRange& range, + const Ident& name, + const Maybe& superclass, + const List& body) { + return ClassDef(Compound::create( + TK_CLASS_DEF, + range, + {name, + superclass, + body, + Maybe>::create(range), + Maybe>::create(range)})); + } + static ClassDef create( + const SourceRange& range, + const Ident& name, + const Maybe& superclass, + const List& body, + const List& properties, + const List& assigns); +}; + +TORCH_API std::vector getUnresolvedClassAttributes( + const ClassDef& def); + +//////////////////////////////////////////////////////////////////////////////// +// Statements +//////////////////////////////////////////////////////////////////////////////// + +struct If : public Stmt { + explicit If(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_IF); + } + Expr cond() const { + return Expr(subtree(0)); + } + List trueBranch() const { + return List(subtree(1)); + } + List falseBranch() const { + return List(subtree(2)); + } + If withNewBranches( + const List& true_branch, + const List& false_branch) const { + return create(range(), cond(), true_branch, false_branch); + } + static If create( + const SourceRange& range, + const Expr& cond, + const List& true_branch, + const List& false_branch) { + return If( + Compound::create(TK_IF, range, {cond, true_branch, false_branch})); + } +}; + +struct While : public Stmt { + explicit While(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_WHILE); + } + Expr cond() const { + return Expr(subtree(0)); + } + List body() const { + return List(subtree(1)); + } + static While create( + const SourceRange& range, + const Expr& cond, + const List& body) { + return While(Compound::create(TK_WHILE, range, {cond, body})); + } +}; + +struct For : public Stmt { + explicit For(const TreeRef& tree) : Stmt(tree) { + tree->match(TK_FOR); + } + List targets() const { + return List(subtree(0)); + } + List itrs() const { + return List(subtree(1)); + } + List body() const { + return List(subtree(2)); + } + static For create( + const SourceRange& range, + const List& targets, + const List& itrs, + const List& body) { + return For(Compound::create(TK_FOR, range, {targets, itrs, body})); + } +}; + +// TODO: supports only single comprehension for now +struct ListComp : public Expr { + explicit ListComp(const TreeRef& tree) : Expr(tree) { + tree->match(TK_LIST_COMP); + } + Expr elt() const { + return Expr(subtree(0)); + } + Expr target() const { + return Expr(subtree(1)); + } + Expr iter() const { + return Expr(subtree(2)); + } + // TODO: no ifs for now + static ListComp create( + const SourceRange& range, + const Expr& elt, + const Expr& target, + const Expr& iter) { + return ListComp(Compound::create(TK_LIST_COMP, range, {elt, target, iter})); + } +}; + +// TODO: supports only single comprehension for now +struct DictComp : public Expr { + explicit DictComp(const TreeRef& tree) : Expr(tree) { + tree->match(TK_DICT_COMP); + } + Expr key() const { + return Expr(subtree(0)); + } + Expr value() const { + return Expr(subtree(1)); + } + Expr target() const { + return Expr(subtree(2)); + } + Expr iter() const { + return Expr(subtree(3)); + } + // TODO: no ifs for now + static DictComp create( + const SourceRange& range, + const Expr& key, + const Expr& value, + const Expr& target, + const Expr& iter) { + return DictComp( + Compound::create(TK_DICT_COMP, range, {key, value, target, iter})); + } +}; + +struct Global : public Stmt { + explicit Global(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_GLOBAL); + } + List names() { + return List(subtree(0)); + } + static Global create(const SourceRange& range, const List& names) { + return Global(Compound::create(TK_GLOBAL, range, {names})); + } +}; + +struct AugAssignKind : public TreeView { + explicit AugAssignKind(const TreeRef& tree) : TreeView(tree) { + switch (tree->kind()) { + case '+': + case '-': + case '*': + case '/': + case '%': + case '|': + case '&': + case '^': + case TK_POW: + case TK_LSHIFT: + case TK_RSHIFT: + return; + default: + throw(ErrorReport(tree) << "is not a valid AugAssignKind"); + } + } +}; + +// Augmented assignment, like "foo += bar" +struct AugAssign : public Stmt { + explicit AugAssign(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_AUG_ASSIGN); + } + static AugAssign create( + const SourceRange& range, + const Expr& lhs, + const AugAssignKind& aug_op, + const Expr& rhs) { + return AugAssign( + Compound::create(TK_AUG_ASSIGN, range, {lhs, aug_op, rhs})); + } + Expr lhs() const { + return Expr(subtree(0)); + } + int aug_op() const { + return subtree(1)->kind(); + } + Expr rhs() const { + return Expr(subtree(2)); + } +}; + +struct Assign : public Stmt { + explicit Assign(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_ASSIGN); + } + static Assign create( + const SourceRange& range, + const List& lhs, + const Maybe& rhs, + const Maybe& type) { + return Assign(Compound::create(TK_ASSIGN, range, {lhs, rhs, type})); + } + + List lhs_list() const { + return List(subtree(0)); + } + + Expr lhs() const { + const auto& li = lhs_list(); + TORCH_INTERNAL_ASSERT(li.size() == 1); + return *li.begin(); + } + + Maybe rhs() const { + return Maybe(subtree(1)); + } + + Maybe type() const { + return Maybe(subtree(2)); + } +}; + +struct Return : public Stmt { + explicit Return(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_RETURN); + } + Expr expr() const { + return Expr(subtree(0)); + } + static Return create(const SourceRange& range, const Expr& value) { + return Return(Compound::create(TK_RETURN, range, {value})); + } +}; + +struct Raise : public Stmt { + explicit Raise(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_RAISE); + } + Expr expr() const { + return Expr(subtree(0)); + } + static Raise create(const SourceRange& range, const Expr& expr) { + return Raise(Compound::create(TK_RAISE, range, {expr})); + } +}; + +struct Assert : public Stmt { + explicit Assert(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_ASSERT); + } + Expr test() const { + return Expr(subtree(0)); + } + Maybe msg() const { + return Maybe(subtree(1)); + } + static Assert create( + const SourceRange& range, + const Expr& test, + const Maybe& msg) { + return Assert(Compound::create(TK_ASSERT, range, {test, msg})); + } +}; + +struct Pass : public Stmt { + explicit Pass(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_PASS); + } + static Pass create(const SourceRange& range) { + return Pass(Compound::create(TK_PASS, range, {})); + } +}; + +struct Dots : public Expr { + explicit Dots(const TreeRef& tree) : Expr(tree) { + tree_->match(TK_DOTS); + } + static Dots create(const SourceRange& range) { + return Dots(Compound::create(TK_DOTS, range, {})); + } +}; + +struct Break : public Stmt { + explicit Break(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_BREAK); + } + static Break create(const SourceRange& range) { + return Break(Compound::create(TK_BREAK, range, {})); + } +}; + +struct Continue : public Stmt { + explicit Continue(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_CONTINUE); + } + static Continue create(const SourceRange& range) { + return Continue(Compound::create(TK_CONTINUE, range, {})); + } +}; + +struct ExprStmt : public Stmt { + explicit ExprStmt(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_EXPR_STMT); + } + Expr expr() { + return Expr(subtree(0)); + } + static ExprStmt create(const SourceRange& range, const Expr& list) { + return ExprStmt(Compound::create(TK_EXPR_STMT, range, {list})); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Expressions +//////////////////////////////////////////////////////////////////////////////// + +struct BinOp : public Expr { + explicit BinOp(const TreeRef& tree) : Expr(tree) { + switch (tree->kind()) { + case TK_AND: + case TK_OR: + case '<': + case '>': + case TK_IS: + case TK_ISNOT: + case TK_EQ: + case TK_LE: + case TK_GE: + case TK_NE: + case '+': + case '*': + case '/': + case '-': + case '@': + case TK_POW: + case TK_LSHIFT: + case TK_RSHIFT: + case '%': + case '&': + case '^': + case '|': + case TK_FLOOR_DIV: + case TK_IN: + if (tree->trees().size() != 2) + throw( + ErrorReport(tree) + << "BinOp expected 2 subtrees, found " << tree->trees().size()); + return; + default: + throw( + ErrorReport(tree) + << kindToString(tree->kind()) << " is not a valid BinOp"); + } + } + Expr lhs() const { + return Expr(subtree(0)); + } + Expr rhs() const { + return Expr(subtree(1)); + } + static BinOp create( + const SourceRange& range, + int kind, + const Expr& lhs, + const Expr& rhs) { + return BinOp(Compound::create(kind, range, {lhs, rhs})); + } +}; + +struct UnaryOp : public Expr { + explicit UnaryOp(const TreeRef& tree) : Expr(tree) { + switch (tree->kind()) { + case TK_UNARY_MINUS: + case '~': + case TK_NOT: + if (tree->trees().size() != 1) + throw( + ErrorReport(tree) + << "UnaryOp expected 1 subtree, found " << tree->trees().size()); + return; + default: + throw( + ErrorReport(tree) + << kindToString(tree->kind()) << " is not a valid UnaryOp"); + } + } + static UnaryOp create(const SourceRange& range, int kind, const Expr& expr) { + return UnaryOp(Compound::create(kind, range, {expr})); + } +}; + +struct Const : public Expr { + explicit Const(const TreeRef& tree) : Expr(tree) { + tree_->matchNumSubtrees(TK_CONST, 1); + } + bool isFloatingPoint() const { + if (isComplex()) + return false; + + bool is_inf = subtree(0)->stringValue() == "inf"; + return is_inf || + subtree(0)->stringValue().find_first_of(".eE") != std::string::npos; + } + bool isIntegral() const { + return !isFloatingPoint() && !isComplex(); + } + bool isComplex() const { + return subtree(0)->stringValue().find_first_of('j') != std::string::npos; + } + int64_t asIntegral() const { + try { + return std::stoll(subtree(0)->stringValue(), nullptr, 0); + } catch (const std::out_of_range&) { + throw( + ErrorReport(range()) << "Integral constant out of range " + "(must fit in a signed 64 bit integer)"); + } + } + double asFloatingPoint() const { + // We can't pass in nullptr as the dummy pointer gets dereferenced for + // Android version of strtod_c(). + char* dummy = nullptr; + return torch::jit::strtod_c(subtree(0)->stringValue().c_str(), &dummy); + } + c10::complex asComplex() const { + char* dummy = nullptr; + auto str = subtree(0)->stringValue(); + // Complex numbers (a+bj, where a is non-zero) are parsed as an addition + // between float/int a and a complex number "bj". When a is 0, a complex + // number bj is created as above. So, while parsing the string, we don't + // have to worry about the real component of the complex number. + auto imag = + torch::jit::strtod_c(str.substr(0, str.size() - 1).c_str(), &dummy); + return c10::complex(0, imag); + } + const std::string& text() const { + return subtree(0)->stringValue(); + } + static Const create(const SourceRange& range, const std::string& value) { + return Const(Compound::create(TK_CONST, range, {String::create(value)})); + } +}; + +struct StringLiteral : public Expr { + explicit StringLiteral(const TreeRef& tree) : Expr(tree) { + tree_->matchNumSubtrees(TK_STRINGLITERAL, 1); + } + const std::string& text() const { + return subtree(0)->stringValue(); + } + static StringLiteral create( + const SourceRange& range, + const std::string& value) { + return StringLiteral( + Compound::create(TK_STRINGLITERAL, range, {String::create(value)})); + } +}; + +struct Apply : public Expr { + explicit Apply(const TreeRef& tree) : Expr(tree) { + tree_->match(TK_APPLY); + } + Expr callee() const { + return Expr(subtree(0)); + } + List inputs() const { + return List(subtree(1)); + } + List attributes() const { + return List(subtree(2)); + } + static Apply create( + const SourceRange& range, + const Expr& callee, + const List& inputs, + const List& attributes) { + return Apply( + Compound::create(TK_APPLY, range, {callee, inputs, attributes})); + } +}; + +struct Select : public Expr { + explicit Select(const TreeRef& tree) : Expr(tree) { + tree_->match('.'); + } + Expr value() const { + return Expr(subtree(0)); + } + Ident selector() const { + return Ident(subtree(1)); + } + static Select create( + const SourceRange& range, + const Expr& value, + const Ident& selector) { + return Select(Compound::create('.', range, {value, selector})); + } +}; + +struct SliceExpr : public Expr { + explicit SliceExpr(const TreeRef& tree) : Expr(tree) { + tree_->match(TK_SLICE_EXPR); + } + Maybe start() const { + return Maybe(subtree(0)); + } + Maybe end() const { + return Maybe(subtree(1)); + } + Maybe step() const { + return Maybe(subtree(2)); + } + Expr startOr(int64_t alternative) const { + const auto startOption = start(); + return startOption.present() ? startOption.get() : createInt(alternative); + } + Expr endOr(int64_t alternative) const { + const auto endOption = end(); + return endOption.present() ? endOption.get() : createInt(alternative); + } + Expr stepOr(int64_t alternative) const { + const auto stepOption = step(); + return stepOption.present() ? stepOption.get() : createInt(alternative); + } + static SliceExpr create( + const SourceRange& range, + const Maybe& start, + const Maybe& end, + const Maybe& step) { + return SliceExpr( + Compound::create(TK_SLICE_EXPR, range, {start, end, step})); + } + + private: + Expr createInt(int64_t value) const { + return Expr(Const::create(range(), std::to_string(value))); + } +}; + +struct Subscript : public Expr { + explicit Subscript(const TreeRef& tree) : Expr(tree) { + tree_->match(TK_SUBSCRIPT); + } + Expr value() const { + return Expr(subtree(0)); + } + List subscript_exprs() const { + return List(subtree(1)); + } + static Subscript create( + const SourceRange& range, + const Expr& value, + const List& subscript_exprs) { + auto whole_range = SourceRange( + range.source(), range.start(), subscript_exprs.range().end() + 1); + return Subscript( + Compound::create(TK_SUBSCRIPT, whole_range, {value, subscript_exprs})); + } +}; + +struct Var : public Expr { + explicit Var(const TreeRef& tree) : Expr(tree) { + tree_->match(TK_VAR); + } + Ident name() const { + return Ident(subtree(0)); + } + static Var create(const SourceRange& range, const Ident& name) { + return Var(Compound::create(TK_VAR, range, {name})); + } +}; + +// WithItem represents an item using with a WithStmt. +struct WithItem : public Expr { + explicit WithItem(const TreeRef& tree) : Expr(tree) { + tree_->match(TK_WITH_ITEM); + } + + Expr target() const { + return Expr(subtree(0)); + } + + Maybe var() const { + return Maybe(subtree(1)); + } + + static WithItem create( + const SourceRange& range, + const Expr& target, + const Maybe& var) { + return WithItem(Compound::create(TK_WITH_ITEM, range, {target, var})); + } +}; + +// With represents a with statement consisting of a list of with items and a +// body of statements. +struct With : public Stmt { + explicit With(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_WITH); + } + + List targets() const { + return List(subtree(0)); + } + + List body() const { + return List(subtree(1)); + } + + static With create( + const SourceRange& range, + const List& targets, + const List& body) { + return With(Compound::create(TK_WITH, range, {targets, body})); + } +}; + +struct TernaryIf : public Expr { + explicit TernaryIf(const TreeRef& tree) : Expr(tree) { + tree_->matchNumSubtrees(TK_IF_EXPR, 3); + } + Expr cond() const { + return Expr(subtree(0)); + } + Expr true_expr() const { + return Expr(subtree(1)); + } + Expr false_expr() const { + return Expr(subtree(2)); + } + static TernaryIf create( + const SourceRange& range, + const Expr& cond, + const Expr& true_expr, + const Expr& false_expr) { + return TernaryIf( + Compound::create(TK_IF_EXPR, range, {cond, true_expr, false_expr})); + } +}; + +struct ListLiteral : public Expr { + explicit ListLiteral(const TreeRef& tree) : Expr(tree) { + tree_->match(TK_LIST_LITERAL); + } + List inputs() const { + return subtree(0); + } + static ListLiteral create( + const SourceRange& range, + const List& inputs) { + return ListLiteral(Compound::create(TK_LIST_LITERAL, range, {inputs})); + } +}; + +struct TupleLiteral : public Expr { + explicit TupleLiteral(const TreeRef& tree) : Expr(tree) { + tree_->match(TK_TUPLE_LITERAL); + } + List inputs() const { + return subtree(0); + } + static TupleLiteral create( + const SourceRange& range, + const List& inputs) { + return TupleLiteral(Compound::create(TK_TUPLE_LITERAL, range, {inputs})); + } +}; + +struct DictLiteral : public Expr { + explicit DictLiteral(const TreeRef& tree) : Expr(tree) { + tree_->match(TK_DICT_LITERAL); + } + List key_inputs() const { + return subtree(0); + } + List value_inputs() const { + return subtree(1); + } + static DictLiteral create( + const SourceRange& range, + const List& keys, + const List& values) { + return DictLiteral( + Compound::create(TK_DICT_LITERAL, range, {keys, values})); + } +}; + +struct Starred : public Expr { + explicit Starred(const TreeRef& tree) : Expr(tree) { + tree_->match(TK_STARRED); + } + Expr expr() const { + return Expr(subtree(0)); + } + static Starred create(const SourceRange& range, const Expr& expr) { + return Starred(Compound::create(TK_STARRED, range, {expr})); + } +}; + +struct Delete : public Stmt { + explicit Delete(const TreeRef& tree) : Stmt(tree) { + tree_->match(TK_DELETE); + } + List targets() const { + return subtree(0); + } + static Delete create(const SourceRange& range, const List& targets) { + return Delete(Compound::create(TK_DELETE, range, {targets})); + } +}; + +/* + * NOTE: transforming PEP 604 union into equivalent union type + * + * NOTE: Union[int, float] parses into: + * expr:(subscript + * (variable (ident Union)) + * (list + * (variable (ident int)) + * (variable (ident float)))) + * subscript + * + * NOTE: (int | float) parses into: + * expr:(| + * (variable (ident int)) + * (variable (ident float))) + * | + */ + +inline void _flatten_pep604_union( + const torch::jit::Expr& node, + std::vector* result) { + // flatten possibly nested union expressions like (int | (float | str)) + // into a flat list of expressions like [int, float, str] + if (node.kind() == '|') { + auto as_binop = torch::jit::BinOp(node); + _flatten_pep604_union(as_binop.lhs(), result); + _flatten_pep604_union(as_binop.rhs(), result); + } else { + result->push_back(node); + } +} + +inline std::vector get_pep604_union_members(const Expr& node) { + std::vector result; + _flatten_pep604_union(node, &result); + return result; +} + +// Flattens a PEP 604 union into a classical union. +// For example, ((x | y) | z) is transformed into Union[x, y, z]. +inline Expr pep604union_to_union(const Expr& expr) { + // noop if not a pep604 union + if (expr.kind() != '|') + return expr; + + // In order to support unions with more than 2 operands ((x|y)|z), we need to + // recursively flatten the tree of | expressions. + auto members = get_pep604_union_members(expr); + auto synthesised_union = Subscript::create( + expr.range(), + Var::create(expr.range(), Ident::create(expr.range(), "Union")), + List::create(expr.range(), members)); +#if defined(__clang__) + return std::move(synthesised_union); +#else + return synthesised_union; +#endif +} + +} // namespace torch::jit + +namespace std { + +template +struct iterator_traits> + : std::iterator_traits {}; + +} // namespace std diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/versioned_symbols.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/versioned_symbols.h new file mode 100644 index 0000000000000000000000000000000000000000..fc6b0ff7a960c4a8fed6d9335f93015f4fb55342 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/frontend/versioned_symbols.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include + +#include + +namespace torch::jit { +// Maps the given symbol into an implementation of its behavior at the +// given version. +// See note [Versioned Symbols] +TORCH_API Symbol +get_symbol_for_version(const Symbol name, const uint64_t version); + +// Maps the given kind to the minimum version that supports it. +// See note [Dynamic Versions and torch.jit.save vs. torch.save] +TORCH_API uint64_t get_min_version_for_kind(const NodeKind& kind); +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/alias_analysis.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/alias_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..f83c96a2da186f4f21c99e3b68403e1b58dcd2fa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/alias_analysis.h @@ -0,0 +1,363 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace torch::jit { + +class ValueAndMemoryLocationSet; + +/** + * Alias analysis pass. + * + * This pass produces an AliasDb that contains aliasing and mutation + * information about the graph. Users can use this information to determine + * whether mutations to the graph are safe, i.e. they don't reorder/change + * nodes in a way that affects output. + * + * Every value with a mutable type (Tensors, Lists, Tuples, etc.) will be + * associated with one or more "alias sets". If two values share an alias set, + * that means they may alias, implying that a mutation to one value cannot be + * reordered past a use of the other. Only reordering two reads of an alias set + * is considered safe. + * + * There is a special alias set called the "wildcard set", which indicates that + * we're not sure what this value may alias. To be conservative, we consider the + * wildcard alias set as potentially aliasing any other wildcard value within + * the same type class. Whenever a value becomes contained by another value, + * such as when a Tensor is appended to a List[Tensor], the contained element + * becomes part of the wildcard set. + * + * Values that contain other mutable types, such as List[Tensor], are + * initialized as containing the Wildcard set for all contained mutable types. + * + * The AliasDb API references the idea of "mutable" vs "immutable" + * types. "Mutable" means that the object's value can change, while + * "immutable" means that the value is fixed. (For example, `List` is + * mutable, so you can add and delete elements from it. On the other + * hand, you can't modify a Tuple once you create it, making `Tuple` an + * immutable container.) + * + * `isFrozen` - if the Module is frozen then consider attributes as freshly + * created objects. Freezing API invokes alias analysis to check if they are + * mutated internally. + * + * `descendFunctionCalls` - recursively analyze function and method calls + * instead of conservative analysis. Generally analysis should be done after + * inlining so the implmentation for recursive analysis is unoptimized. + */ +class AliasDb { + public: + TORCH_API explicit AliasDb( + std::shared_ptr graphi, + bool isFrozen = false, + bool descendFunctionCalls = false); + TORCH_API ~AliasDb(); + + // There are limitations to what effects the alias analysis can track. Two + // kinds of nodes may have untracked effects: + // 1. Nodes that write to a value that may alias the graph inputs (since + // the inputs can be used outside the graph). + // 2. Nodes that write to something in the wildcard set. + // + // These nodes are considered not safe to eliminate or mutate under any + // circumstances. + bool writesToWildcard(Node* n) const; + + // Does `n` write to an alias of one of the values in `vs`? + // if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks + TORCH_API bool writesToAlias(Node* n, const ValueSet& vs) const; + + // Does `n` write to any of the values in `vls`? + TORCH_API bool writesToAlias(Node* n, const ValueAndMemoryLocationSet& vls) + const; + + TORCH_API ValueAndMemoryLocationSet getValueAndMemoryLocationSet() const; + + // Does `a` and `b` potentially share a memory location or do either + // hold in memory any element that exists in the other + TORCH_API bool mayContainAlias(Value* a, Value* b) const; + + TORCH_API bool mayContainAlias(Value* a, const at::ArrayRef b) const; + + // Do any values in group `a` share a memory location or hold in memory + // any element that exists in group `b` + TORCH_API bool mayContainAlias( + const at::ArrayRef a, + const at::ArrayRef b) const; + + // Do `a` and `b` potentially share a memory location? + TORCH_API bool mayAlias(const Value* a, const Value* b) const; + // Do any values in group `a` potentially share a memory location with any + // value in group `b`? i.e. may they overlap? + TORCH_API bool mayAlias(const ValueSet& a, const ValueSet& b) const; + + // Do any nodes write to an alias set input to `n`? + TORCH_API bool hasInputWriters(const Node* n) const; + + // Do any nodes write to an alias set output by `n`? + TORCH_API bool hasOutputWriters(const Node* n) const; + + // Do any nodes write to an alias set inputed/outputed by `n`? + TORCH_API bool hasWriters(const Node* n) const; + + // Do any nodes write to `v`s memory location? + TORCH_API bool hasWriters(const Value* v) const; + + // Is the operation in-place? i.e. doesn't write anywhere but locations it + // reads from. + TORCH_API bool isMutable(Node* n) const; + + TORCH_API bool escapesScope(const at::ArrayRef& vs) const; + + // Is it safe to change whether `a` and `b` alias each other ? + TORCH_API bool safeToChangeAliasingRelationship( + const at::ArrayRef& a, + const at::ArrayRef& b) const; + + // Move `n` (already in the graph) after `movePoint` in the topological order. + // + // Tries to preserve value dependencies, so other nodes might be moved. We + // make two guarantees about the postcondition of the node list: + // - `n` is directly after `movePoint`. + // - only nodes between `n` and `movePoint` have been moved. + // + // Returns `false` if it's impossible to move `n` after `MovePoint` without + // violating dependencies, otherwise executes the move and returns `true` + TORCH_API bool moveAfterTopologicallyValid(Node* n, Node* movePoint); + TORCH_API bool moveBeforeTopologicallyValid(Node* n, Node* movePoint); + + bool couldMoveAfterTopologically(Node* n, Node* movePoint); + bool couldMoveBeforeTopologically(Node* n, Node* movePoint); + + // For debugging: print alias db state to stdout + TORCH_API void dump() const; + TORCH_API std::string toString() const; + + // Generates a DOT (www.graphviz.org) graph representation + // + // Returns `true` if the output file was successfully generated + // + // WARNING: The output dot file path can't include shell specific notations, + // for example you can't use "~/temp/aliasdb.dot" + // (instead, use "/home/user/temp/aliasdb.dot") + // + TORCH_API bool dumpToGraphvizFile(const char* filename) const; + TORCH_API std::string toGraphviz() const; + + // Returns `true` if the given element is mutable or if it is a + // container type with an internal mutable element (e.g. + // `Tuple[int, Tensor]` has an internal mutable type `Tensor`, so + // it would be considered a "mutable type" in AliasDb) + static bool isMutableType(const Value* v); + static bool isMutableType(const TypePtr& type); + + /** + * Mutation API + * + * These methods allow you to update AliasDb in-place if you are performing + * graph mutation. + * + * WARNING: These methods should be considered INTERNAL. They do not perform + * very many correctness checks, the user is responsible for making sure they + * are updating AliasDb correctly. `Lint()`ing the AliasDb can help with + * this. + */ + // Copy `existing`s aliasing info to `new_value`, and remove `existing`. + TORCH_API void replaceWithNewValue(Value* existing, Value* new_value); + // Copy `from`s aliasing info to `to`. + TORCH_API void copyValue(Value* from, Value* to); + // Create a new `value` that does not alias anything else. + TORCH_API void createValue(const Value* value); + + // Enable more precise treatment of prim::TupleConstruct. + void enablePreciseTupleContainerAnalysis(); + + friend struct MutationRemover; + friend class ValueAndMemoryLocationSet; + + private: + // Helper for topologically-safe node moves. + class WorkingSet; + enum class MoveSide { BEFORE, AFTER }; + bool tryMove(Node* toMove, Node* movePoint, MoveSide moveSide, bool dryRun); + void move(Node* toMove, Node* movePoint, MoveSide moveSide); + bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const; + + bool isMutableTypeInternal(const Value* v) const; + bool isMutableTypeInternal(const TypePtr& type) const; + + /** + * Write and read internal API + */ + // Get all the values that `n` writes to. + // NOTE: this only returns values directly written to, not aliases thereof + // + // if `recurseBlocks` is true, gather writes on the nodes in `n`s sub-blocks + MemoryLocations getWrites(Node* n) const; + void getWritesImpl(Node* n, MemoryLocations& ret) const; + // Register the fact that `n` writes to `v`. + void registerWrite(const Value* v, Node* n, bool writeToContained = false); + // Get all the values that `n` reads from. + // if `recurseBlocks` is true, gather reads on the nodes in `n`s sub-blocks + MemoryLocations getReads(Node* n) const; + void getReadsImpl(Node* n, MemoryLocations& ret) const; + MemoryLocations getMemoryLocations(Value* v) const; + + /** + * Wildcard methods + */ + // Register `v` as a wildcard value. + std::optional setWildcard(const Value* v); + + // Is this a value which will not alias? + bool nonAliasingValue(const Value* elem) const; + + /** + * Special analysis methods + */ + void analyze(const std::shared_ptr& graph); + void analyze(Block* block); + void analyze(Node* node); + void analyzeImpl(Node* node); + void analyzeIf(Node* node); + void analyzeLoop(Node* node); + void analyzeSubgraph(Node* node, const std::shared_ptr& subgraph); + void analyzeSubgraph(Node* node); + void analyzeCreator(Node* node); + void analyzeExtractor(Node* node); + void analyzeChunk(Node* node); + void analyzeBroadcastingChunk(Node* node); + void analyzeFork(Node* node); + void analyzeWait(Node* node); + void analyzeAwaitable(Node* node); + void analyzeAwaitableWait(Node* node); + void analyzeRpcAsync(Node* node); + void analyzeBatchNorm(Node* node); + void analyzeInstanceNorm(Node* node); + void analyzeGradOf(Node* node); + void analyzeSetAttr(Node* node); + void analyzeConservative(Node* node); + void analyzeContainerConstruct(Node* node); + bool tryRegisteredAnalysis(Node* node); + + /** + * Alias manipulation methods + */ + void makeAllAlias(const std::vector& values); + void makePointerTo(const Value* value, const Value* to); + TORCH_API void addToContainedElements( + const Value* element, + const Value* container); + void mapAliases(at::ArrayRef to, at::ArrayRef from); + void giveFreshAlias( + const Value* value, + bool add_wildcard_to_contained_elems = true); + Element* getOrCreateElement(const Value* value); + + const AliasTypeSet* mapTypeToAliasTypeSetPtr(const TypePtr& type) const; + bool functionalNonEscapingListUse(const Use& use) const; + bool functionalNonEscapingTupleUse(const Use& use) const; + + std::shared_ptr graph_; + + // If the Module is frozen then consider attributes as freshly created + // objects. Freezing API invokes alias analysis to check if they are mutated + // internally. + bool isFrozen_; + + bool descend_function_calls_; + std::unordered_map>> + function_call_copies_; + + // The points-to graph that stores aliasing relationships + std::unique_ptr memoryDAGBuilder_; + std::unique_ptr memoryDAG_; + + // Mapping of values to MemoryDAG elements + ska::flat_hash_map elementMap_; + // All wildcard Elements (one for each unique mutable type) + ska::flat_hash_map wildcardIndex_; + Element* getWildcard(const TypePtr& type) const; + std::optional tryGetOrCreateWildcard(const TypePtr& type); + void addContainedTypesToFreshElement( + Element* container_elem, + const AliasTypeSet& mut_types); + void pointUnionTypeElementToAllContainedTypes( + Element* container_elem, + const AliasTypeSet& mut_types); + + std::vector getElements(at::ArrayRef vs) const; + bool mayAliasWildcard(const Value* v) const; + bool mayAliasWildcard(const at::ArrayRef vs) const; + bool hasWriters(const at::ArrayRef& values) const; + + // Cached mapping of type ptrs to their mutable types + mutable ska::flat_hash_map mapped_mutable_types_; + + /** + * State for tracking write info. + */ + // Write registry where the analysis can record the writes as it sees them. + // This information is later denormalized into various caches to improve query + // efficiency. + struct WriteRegistry; + std::unique_ptr writeRegistry_; + + // Map of nodes to the memory locations that they write to + using TWriteIndex = ska::flat_hash_map; + std::optional writeIndex_; + // Collection of all memory locations that are written to. + std::optional writtenToLocationsIndex_; + void buildWrittenToLocationsIndex(); + + std::unordered_set wildcards_; + + std::string getElementName(const Element* e) const; + + friend void Lint(const AliasDb* db); +}; + +// Helper check that invariants over AliasDb are maintained. +// Useful if you are using the AliasDb mutation API and want to check you did +// the right thing. +TORCH_API void Lint(const AliasDb* db); + +/** + * ValueAndMemoryLocationSet + * + * A insert-only set of values which also maintains a MemoryLocations bitset + * of the memory locations that the values alias. It is insert-only. It + * should be constructed by calling aliasDb.getValueAndMemoryLocationSet(). + * + * WARNING: + * * The AliasDb must not be mutated after construction of a + * ValueAndMemoryLocationsSet, or else the MemoryLocations stored in the + * ValueAndMemoryLocationSet will no longer be accurate. + * * A ValueAndMemoryLocationsSet is tied to an instsance of AliasDb but + * does not own the AliasDb. It is the user's responsibility to ensure + * that the AliasDb outlives the ValuesAndMemoryLocationsSet. + * + * The use case for this is to be able to implement writesToAlias + * more efficiently for a set of values. + */ +class ValueAndMemoryLocationSet { + public: + TORCH_API void insert(Value* v); + TORCH_API ValueSet& getValueSet(); + + friend class AliasDb; + + private: + ValueAndMemoryLocationSet(const AliasDb* db) : aliasDb_(db) {} + + const AliasDb* aliasDb_; + ValueSet valueSet_; + MemoryLocations memoryLocations_; +}; + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/attributes.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/attributes.h new file mode 100644 index 0000000000000000000000000000000000000000..f6e8f214807831227c8f596fc4071fbc4a2a6518 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/attributes.h @@ -0,0 +1,180 @@ +#pragma once +#include +#include +#include + +#include +#include + +#include + +namespace torch::jit { + +using ::c10::Symbol; + +constexpr int max_tensor_display_size = 10; + +enum class AttributeKind { + f, + fs, + c, + cs, + i, + is, + s, + ss, + t, + ts, + g, + gs, + ty, + tys, + ival +}; +static inline const char* toString(AttributeKind kind) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + static const char* names[] = { + "f", + "c", + "cs", + "fs", + "i", + "is", + "s", + "ss", + "t", + "ts", + "g", + "gs", + "ty", + "tys", + "ival"}; + AT_ASSERT(size_t(kind) < sizeof(names) / sizeof(*names)); + return names[int(kind)]; +} + +struct AttributeValue { + AttributeValue(Symbol name) : name(name) {} + using Ptr = std::unique_ptr; + Symbol name; + virtual AttributeKind kind() const = 0; + virtual Ptr clone() const = 0; + virtual ~AttributeValue() = default; +}; + +template +struct ScalarAttributeValue : public AttributeValue { + using ConstructorType = T; + using ValueType = T; + ScalarAttributeValue(Symbol name, ConstructorType value_) + : AttributeValue(name), value_(std::move(value_)) {} + ValueType& value() { + return value_; + } + Ptr clone() const override { + return Ptr(new ScalarAttributeValue(name, value_)); + } + AttributeKind kind() const override { + return Kind; + } + + private: + ValueType value_; +}; + +template +struct VectorAttributeValue : public AttributeValue { + using ConstructorType = std::vector; + using ValueType = std::vector; + VectorAttributeValue(Symbol name, ConstructorType value_) + : AttributeValue(name), value_(std::move(value_)) {} + ValueType& value() { + return value_; + } + AttributeKind kind() const override { + return Kind; + } + std::unique_ptr clone() const override { + auto copy = value_; + return Ptr(new VectorAttributeValue(name, std::move(copy))); + } + + private: + ValueType value_; +}; + +using ComplexAttr = + ScalarAttributeValue, AttributeKind::c>; +using ComplexValsAttr = + VectorAttributeValue, AttributeKind::cs>; +using FloatAttr = ScalarAttributeValue; +using FloatsAttr = VectorAttributeValue; +using IntAttr = ScalarAttributeValue; +using IntsAttr = VectorAttributeValue; +using StringAttr = ScalarAttributeValue; +using StringsAttr = VectorAttributeValue; +using TensorAttr = ScalarAttributeValue; +using TensorsAttr = VectorAttributeValue; +using TypeAttr = ScalarAttributeValue; +using TypesAttr = VectorAttributeValue; +using IValueAttr = ScalarAttributeValue; + +struct Graph; + +// We special case Graph attributes like this because we want to ensure that +// Graph::copy() is called when we clone() these attributes. +struct TORCH_API GraphAttr : public AttributeValue { + using ConstructorType = std::shared_ptr; + using ValueType = std::shared_ptr; + GraphAttr(Symbol name, ConstructorType value_) + : AttributeValue(name), value_(std::move(value_)) {} + ValueType& value() { + return value_; + } + Ptr clone() const override; + AttributeKind kind() const override { + return AttributeKind::g; + } + + private: + std::shared_ptr value_; +}; + +struct TORCH_API GraphsAttr : public AttributeValue { + using ConstructorType = std::vector>; + using ValueType = std::vector>; + GraphsAttr(Symbol name, ConstructorType value_) + : AttributeValue(name), value_(std::move(value_)) {} + ValueType& value() { + return value_; + } + AttributeKind kind() const override { + return AttributeKind::gs; + } + std::unique_ptr clone() const override; + + private: + ValueType value_; +}; + +struct IRAttributeError : public std::exception { + IRAttributeError(Symbol name, bool defined) { + std::stringstream ss; + // NOLINTNEXTLINE(bugprone-branch-clone) + if (!defined) { + ss << "required keyword attribute '" << name.toUnqualString() + << "' is undefined"; + } else { + ss << "required keyword attribute '" << name.toUnqualString() + << "' has the wrong type"; + } + msg = ss.str(); + } + const char* what() const noexcept override { + return msg.c_str(); + } + + private: + std::string msg; +}; +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/constants.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/constants.h new file mode 100644 index 0000000000000000000000000000000000000000..43d1205a438d21e7f103f0aaf2b872f2779d890b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/constants.h @@ -0,0 +1,60 @@ +#pragma once +#include +#include +#include +#include +#include + +// helpers for handling constants in the IR +// - create constant nodes from ints, floats, complex, intlist, Tensors, and +// other types +// - implement primitive constant ops. + +namespace torch::jit { + +using ::c10::IValue; + +struct Graph; +struct Value; + +// thrown when insertConstant cannot encode the IValue into a graph +struct TORCH_API constant_not_supported_error : public std::runtime_error { + using runtime_error::runtime_error; +}; + +TORCH_API Value* insertConstant( + Graph& g, + const IValue& val, + std::optional loc = std::nullopt, + std::optional scope = std::nullopt); + +// note: prefer g.insertConsant(val, loc) which does exactly the same thing +// this function is only declared/defined here because its implementation is +// closely related to the implementation of prim::Constant that is also in +// constants.cpp. +// +// returns a std::nullopt if the IValue kind cannot be inserted as a constant +TORCH_API std::optional tryInsertConstant( + Graph& g, + const IValue& val, + std::optional loc = std::nullopt, + std::optional scope = std::nullopt); + +//////////////////////////////////////////////////////////////////////////////// +// Helper for retrieving constants +//////////////////////////////////////////////////////////////////////////////// + +// attempt to convert a (possibly constant) Value* into an interpreter value +// (IValue). returns std::nullopt if the Value* was not constant +TORCH_API std::optional toIValue(const Value* v); + +// if a value is a constant then try to turn into type T using the +// same rules as the interpreter +template +std::optional constant_as(const Value* v) { + if (auto ivalue = toIValue(v)) { + return ivalue->to(); + } + return std::nullopt; +} +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/graph_node_list.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/graph_node_list.h new file mode 100644 index 0000000000000000000000000000000000000000..aeed2380e8e8825c62da594f89e6467269db8354 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/graph_node_list.h @@ -0,0 +1,199 @@ +#pragma once + +#include + +namespace torch::jit { + +// Intrusive doubly linked lists with sane reverse iterators. +// The header file is named generic_graph_node_list.h because it is ONLY +// used for Graph's Node lists, and if you want to use it for other +// things, you will have to do some refactoring. +// +// At the moment, the templated type T must support a few operations: +// +// - It must have a field: T* next_in_graph[2] = { nullptr, nullptr }; +// which are used for the intrusive linked list pointers. +// +// - It must have a method 'destroy()', which removes T from the +// list and frees a T. +// +// In practice, we are only using it with Node and const Node. 'destroy()' +// needs to be renegotiated if you want to use this somewhere else. +// +// Regardless of the iteration direction, iterators always physically point +// to the element they logically point to, rather than +// the off-by-one behavior for all standard library reverse iterators like +// std::list. + +// The list is includes two sentinel nodes, one at the beginning and one at the +// end with a circular link between them. It is an error to insert nodes after +// the end sentinel node but before the beginning node: + +// Visualization showing only the next() links: +// HEAD -> first -> second -> ... -> last -> TAIL +// ^------------------------------------------ + +// Visualization showing only the prev() links: +// HEAD <- first <- second <- ... <- last <- TAIL +// ------------------------------------------^ + +static constexpr int kNextDirection = 0; +static constexpr int kPrevDirection = 1; + +template +struct generic_graph_node_list; + +template +struct generic_graph_node_list_iterator; + +struct Node; +using graph_node_list = generic_graph_node_list; +using const_graph_node_list = generic_graph_node_list; +using graph_node_list_iterator = generic_graph_node_list_iterator; +using const_graph_node_list_iterator = + generic_graph_node_list_iterator; + +template +struct generic_graph_node_list_iterator { + generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {} + generic_graph_node_list_iterator(T* cur, int d) : cur(cur), d(d) {} + generic_graph_node_list_iterator( + const generic_graph_node_list_iterator& rhs) = default; + generic_graph_node_list_iterator( + generic_graph_node_list_iterator&& rhs) noexcept = default; + generic_graph_node_list_iterator& operator=( + const generic_graph_node_list_iterator& rhs) = default; + generic_graph_node_list_iterator& operator=( + generic_graph_node_list_iterator&& rhs) noexcept = default; + T* operator*() const { + return cur; + } + T* operator->() const { + return cur; + } + generic_graph_node_list_iterator& operator++() { + AT_ASSERT(cur); + cur = cur->next_in_graph[d]; + return *this; + } + generic_graph_node_list_iterator operator++(int) { + generic_graph_node_list_iterator old = *this; + ++(*this); + return old; + } + generic_graph_node_list_iterator& operator--() { + AT_ASSERT(cur); + cur = cur->next_in_graph[reverseDir()]; + return *this; + } + generic_graph_node_list_iterator operator--(int) { + generic_graph_node_list_iterator old = *this; + --(*this); + return old; + } + + // erase cur without invalidating this iterator + // named differently from destroy so that ->/. bugs do not + // silently cause the wrong one to be called. + // iterator will point to the previous entry after call + void destroyCurrent() { + T* n = cur; + cur = cur->next_in_graph[reverseDir()]; + n->destroy(); + } + generic_graph_node_list_iterator reverse() { + return generic_graph_node_list_iterator(cur, reverseDir()); + } + + private: + int reverseDir() { + return d == kNextDirection ? kPrevDirection : kNextDirection; + } + T* cur; + int d; // direction 0 is forward 1 is reverse, see next_in_graph +}; + +template +struct generic_graph_node_list { + using iterator = generic_graph_node_list_iterator; + using const_iterator = generic_graph_node_list_iterator; + generic_graph_node_list_iterator begin() { + return generic_graph_node_list_iterator(head->next_in_graph[d], d); + } + generic_graph_node_list_iterator begin() const { + return generic_graph_node_list_iterator(head->next_in_graph[d], d); + } + generic_graph_node_list_iterator end() { + return generic_graph_node_list_iterator(head->next_in_graph[!d], d); + } + generic_graph_node_list_iterator end() const { + return generic_graph_node_list_iterator( + head->next_in_graph[!d], d); + } + generic_graph_node_list_iterator rbegin() { + return reverse().begin(); + } + generic_graph_node_list_iterator rbegin() const { + return reverse().begin(); + } + generic_graph_node_list_iterator rend() { + return reverse().end(); + } + generic_graph_node_list_iterator rend() const { + return reverse().end(); + } + generic_graph_node_list reverse() { + return generic_graph_node_list(head->next_in_graph[!d], !d); + } + const generic_graph_node_list reverse() const { + return generic_graph_node_list(head->next_in_graph[!d], !d); + } + T* front() { + return head->next_in_graph[d]; + } + const T* front() const { + return head->next_in_graph[d]; + } + T* back() { + return head->next_in_graph[!d]; + } + const T* back() const { + return head->next_in_graph[!d]; + } + generic_graph_node_list(T* head, int d) : head(head), d(d) {} + + private: + T* head; // both head and tail are sentinel nodes + // the first real node is head->next_in_graph[d] + // the tail sentinel is head->next_in_graph[!d] + int d; +}; + +template +static inline bool operator==( + generic_graph_node_list_iterator a, + generic_graph_node_list_iterator b) { + return *a == *b; +} + +template +static inline bool operator!=( + generic_graph_node_list_iterator a, + generic_graph_node_list_iterator b) { + return *a != *b; +} + +} // namespace torch::jit + +namespace std { + +template +struct iterator_traits> { + using difference_type = int64_t; + using value_type = T*; + using pointer = T**; + using reference = T*&; + using iterator_category = bidirectional_iterator_tag; +}; + +} // namespace std diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/graph_utils.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/graph_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..1c5e702e5900d4bf0fb41c9cfe1aebb6d0f8f1ec --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/graph_utils.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +#include + +namespace torch::jit { + +TORCH_API TypePtr getTensorType(const at::Tensor& t, bool complete); + +TORCH_API TypePtr inferShapeAndTypeForInput( + TypePtr input_type, + Stack::const_iterator& s_iter, + const Stack::const_iterator& s_iter_end, + bool complete); + +TORCH_API void setInputTensorTypes( + Graph& g, + const Stack& stack, + bool complete, + const std::vector& param_count_list = {}); + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/ir.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/ir.h new file mode 100644 index 0000000000000000000000000000000000000000..fc780c26c3dd903d5c6a0c36d6cc95a4fe8d4ece --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/ir/ir.h @@ -0,0 +1,1833 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// Forward declare, the real meat is in python_ir.cpp +template +class THPPointer; +using THPObjectPtr = THPPointer; +using pyobj_list = std::vector; + +namespace torch::jit { +namespace utils { +TORCH_API std::string getNodesModuleHierarchy(const Node& n); +} // namespace utils +class AliasDb; + +using ::c10::Argument; +using ::c10::FunctionSchema; +using ::c10::Symbol; + +using ::c10::ivalue::Shared; + +using ::c10::IValue; +using ::c10::ivalue::Future; + +using ::c10::ivalue::ConstantString; + +#define C10_USING(T) using ::c10::T; +C10_FORALL_TYPES(C10_USING) +#undef C10_USING + +#define C10_USING(T) using ::c10::T##Ptr; +C10_FORALL_TYPES(C10_USING) +#undef C10_USING + +using ::c10::Type; +using ::c10::TypeEnv; +using ::c10::TypePtr; + +using ::c10::getTypePtr; +using ::c10::MatchTypeReturn; +using ::c10::TypeKind; + +using ::c10::fmap; + +namespace prim { +using namespace ::c10::prim; +} +namespace attr { +using namespace ::c10::attr; +} +namespace aten { +using namespace ::c10::aten; +} +namespace cuda { +#if !defined(USE_ROCM) +using namespace ::c10::cuda; +#endif +} // namespace cuda + +struct Function; +struct GraphFunction; +struct MatchedSchema; + +// A Graph represents one "function" of computation. +// It uses a simple ownership model where the graph owns all the nodes inside +// it. All references inside the graph are raw pointers. Destroying the Graph +// will invalidate any pointers to nodes in the graph. +struct Graph; + +// Node is the base class of the IR graph. It represents one computation +// and dependencies on a list of Values. The "prim-ops", so to speak. +struct Node; + +// A Value represents an input or output to node that is either a +// Tensor or an opaque Handle object, as determined by type(). +struct Value; + +TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g); +TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n); + +// A list of nodes, with inputs and outputs +struct Block; + +// Each use is represented by this type, see 'Node::uses()' +// 'user' is the consumer of the value, 'offset' is the index into +// 'user's input this where the producers will be found. +struct Use { + Use(Node* user, size_t offset) : user(user), offset(offset) {} + Node* user; + size_t offset; + + bool operator==(const Use& b) { + return user == b.user && offset == b.offset; + } +}; + +// Note [User node does not uniquely identify use] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// A while back, we wrote some code manipulating uses that looked like this: +// +// for (auto& use : used_val->uses_) { +// if (use.user == this_node) { +// use.offset += 1; +// break; +// } +// } +// +// This code is trying to find a particular use (our node's use) to update it. +// However, it's wrong: there may be *multiple* uses of a value %x in a node, +// as might be the case in this IR: +// +// %y = Add %x %x +// +// In this case, there are two uses of %x whose user is the node 'Add %x %x'. +// So, "use induced by this node" is not a well-formed concept. +// +// If you are looking for "use induced by an input", it's best to use +// findUseForInput() to get it. + +// the list types are intentionally simple, but we type-def +// them here so if we need to change them, refactoring will be easier +using node_list = std::vector; +using value_list = std::vector; +using use_list = std::vector; +template +using ArrayRef = at::ArrayRef; +using NodeKind = Symbol; +using topo_position_t = int64_t; +using ValueSet = std::unordered_set; + +struct OperatorSet; +template +struct OperatorMap; + +// This is a wrapper to allow invalidating the Python object +// safely when the C++ object for a Node/Value/Block is deleted +// like much of graph, it isn't safe for different threads to +// access the same graph +template +struct Wrap { + explicit Wrap(T* p) : elem(p) {} + void clear() { + if (clear_cb) { + clear_cb(elem); + } + elem = nullptr; + } + T* elem; + void (*clear_cb)(void*){nullptr}; +}; + +struct Value { + AT_DISALLOW_COPY_AND_ASSIGN(Value); + Value(Node* node_, size_t offset_); + + private: + friend struct Node; + friend struct Graph; + Node* node_; + size_t offset_; + size_t unique_ = 0; // unique id + use_list uses_; + std::string unique_name_; + TypePtr type_; + // a managing wrapper for Python to allow invalidation + std::shared_ptr> wrap_; + + public: + Value* setType(TypePtr type); + TORCH_API void inferTypeFrom(const at::Tensor& output); + TORCH_API void inferTypeFrom( + const c10::intrusive_ptr& output); + const TypePtr& type() const { + AT_ASSERT(type_ != nullptr); + return type_; + } + bool requires_grad() const { + return type()->requires_grad(); + } + bool isCompleteTensor() const { + if (auto pt = type()->cast()) { + return pt->isComplete(); + } + return false; + } + TORCH_API bool mustBeNone() const; + TORCH_API bool mustNotBeNone() const; + size_t unique() const { + return unique_; + } + bool hasDebugName() const { + return !unique_name_.empty(); + } + static bool isValidName(const std::string& name); + TORCH_API Value* setDebugName(const std::string& name); + std::string debugName() const { + if (hasDebugName()) { + return unique_name_; + } + return std::to_string(unique()); + } + TORCH_API std::string debugNameBase() const; + Node* node() { + return node_; + } + size_t offset() const { + return offset_; + } + void setOffset(size_t offset) { + offset_ = offset; + } + const Node* node() const { + return node_; + } + + /** + * @warning NEVER pass raw pointer of smart pointer managed Graph to Python. + * Check #87343 for details. + */ + Graph* owningGraph(); + const Graph* owningGraph() const; + // TODO: make this more const correct + const use_list& uses() const { + return uses_; + } + + bool hasUses() const { + return !uses().empty(); + } + + TORCH_API void replaceFirstUseWith(Value* newValue); + + // Replaces all uses of this value with 'newValue'. + // + // Given: %3 = f(%1, %2) + // %4 = g(%3) + // %5 = h(%3, %3) + // Execute: %3.replaceAllUsesWith(%6) + // Result: %3 = f(%1, %2) + // %4 = g(%6) + // %5 = h(%6, %6) + TORCH_API void replaceAllUsesWith(Value* newValue); + + // Replaces all uses of this value with 'newValue' after 'node'. + // Given: %3 = f(%1, %2) + // %4 = g(%3) + // %5 = inplace_(%3) + // %6 = h(%3, %3) + // Execute: %3.replaceAllUsesAfterNodeWith(%5.node(), %5) + // Result: %3 = f(%1, %2) + // %4 = g(%3) + // %5 = inplace_(%3) + // %6 = h(%5, %5) + // XXX: does not check scoping legality, consider using + // replaceAllUsesDominatedByNodeWith + TORCH_API void replaceAllUsesAfterNodeWith(const Node* node, Value* newValue); + + // Replaces all uses of this value with 'newValue' that are dominated by + // 'node'. Given: + // x = op(...). + // if cond: + // z = foo(..) + // bar(x) + // else: + // print(x) + // x.replaceAllUsesDominatedByNodeWith(foo, z) would replace bar(x) + // but not print(x) because print is not dominated by foo. + // replaceAllUsesAfterNode does not check domination, so in this example + // it would produce invalid IR. + TORCH_API void replaceAllUsesDominatedByNodeWith( + const Node* node, + Value* newValue); + + TORCH_API Value* copyMetadata(Value* from); + + TORCH_API std::shared_ptr> wrap() { + if (!wrap_) { + wrap_ = std::make_shared>(this); + } + return wrap_; + } + + virtual ~Value() { + if (wrap_) { + wrap_->clear(); + } + } +}; + +struct TORCH_API Node { + AT_DISALLOW_COPY_AND_ASSIGN(Node); + friend struct Graph; + friend struct Block; + friend struct Value; + friend graph_node_list; + friend const_graph_node_list; + friend graph_node_list_iterator; + friend const_graph_node_list_iterator; + + private: + const NodeKind kind_; + std::vector inputs_; + std::vector outputs_; + // subblocks + std::vector blocks_; + Graph* graph_; + Block* owning_block_; + std::optional source_range_; + ScopePtr scope_; + std::optional callstack_; + // Assumes FunctionSchemas are persistent, so we don't manage their lifetime. + // This field is effective a cache that's populated on attribute lookups and + // invalidated every time we perform an operation that could potentially + // change the schema. note: mutable because schema_ is effectively a cache + mutable const Operator* op_; + topo_position_t topo_position_ = 0; + // a managing wrapper for Python to allow invalidation + std::shared_ptr> wrap_; + // Stores the full schema name, if the operator is historic + // When the operator is deprecated or the name of the operator + // is changed, we need to rely on this name + // to retrieve old schemas to successfully apply upgraders + // for this operator. + std::optional historic_schema_name_ = std::nullopt; + + protected: + Node(Graph* graph_, NodeKind kind_); // defined after graph + public: + // Each Node but Return/Param Nodes are associated with exactly one + // place in the Node list of the Graph. The Graph itself is a circular + // doubly-linked list. The Return Node is used as the sentinel for the + // "beginning"/"end" of the list. This means that you can tell when + // you've traversed the entire list without means worrying about null + // pointers. `next_in_graph[0]` is the pointer to the next Node, while + // `next_in_graph[1]` is the pointer to the previous Node. The + // linked list is implemented as an array to allow the same iterator + // class for forward and reversed Node lists. Taken together, this + // list also represents a topological sort of the Nodes in the Graph. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-non-private-member-variables-in-classes,modernize-avoid-c-arrays) + Node* next_in_graph[2] = {nullptr, nullptr}; + + std::shared_ptr> wrap() { + if (!wrap_) { + wrap_ = std::make_shared>(this); + } + return wrap_; + } + + const std::optional getHistoricSchemaName() { + return historic_schema_name_; + } + + void setHistoricSchemaName(const std::string& name) { + historic_schema_name_ = name; + } + + Node*& next() { + return next_in_graph[kNextDirection]; + } + Node*& prev() { + return next_in_graph[kPrevDirection]; + } + Node* const& next() const { + return next_in_graph[kNextDirection]; + } + Node* const& prev() const { + return next_in_graph[kPrevDirection]; + } + + NodeKind kind() const { + return kind_; + } + Node* setSourceRange(SourceRange r) { + source_range_ = std::move(r); + return this; + } + SourceRange sourceRange() const; + + /** + * @warning NEVER pass raw pointer of smart pointer managed Graph to Python. + * Check #87343 for details. + */ + Graph* owningGraph() { + return graph_; + } + const Graph* owningGraph() const { + return graph_; + } + Block* owningBlock() { + return owning_block_; + } + const Block* owningBlock() const { + return owning_block_; + } + ScopePtr scope() { + return scope_; + } + void setScope(ScopePtr scope) { + scope_ = std::move(scope); + } + std::string scopeName() const { + if (!scope_) { + return ""; + } + return scope_->namesFromRoot(); + } + + // Copies the source range, scope and callstack from another node. + Node* copyMetadata(Node* from) { + this->setSourceRange(from->sourceRange()); + this->setScope(from->scope()); + if (auto cs = from->callstack()) { + this->setCallStack(*cs); + } + return this; + } + + std::optional callstack() const { + return callstack_; + } + void setCallStack(InlinedCallStackPtr cs) { + callstack_ = std::move(cs); + } + + // NB: This returns an ArrayRef; that means that it will + // get invalidated if you resize inputs (e.g., using addInput) + // We can't return a std::vector& because there's no + // way to soundly cast to std::vector (an insane + // implementation of std::vector could make this representationally + // different.) + at::ArrayRef inputs() { + return inputs_; + } + at::ArrayRef inputs() const { + // Vectors are not convertible in const-ness of elements, but + // raw pointers are. + return {inputs_.data(), inputs_.size()}; + } + // NB: This returns an ArrayRef; that means that it will + // get invalidated if you resize inputs (e.g., using addInput) + // We can't return a std::vector& because there's no + // way to soundly cast to std::vector (an insane + // implementation of std::vector could make this representationally + // different.) + at::ArrayRef outputs() { + return outputs_; + } + at::ArrayRef outputs() const { + // Vectors are not convertible in const-ness of elements, but + // raw pointers are. + return {outputs_.data(), outputs_.size()}; + } + Value* output(size_t i) const { + return outputs_.at(i); + } + bool hasUses() const { + for (auto o : outputs()) { + if (!o->uses().empty()) { + return true; + } + } + return false; + } + + void replaceAllUsesWith(Node* n); + + // replaces `this` with a new node with the same inputs and outputs + // but a new node symbol. does not destroy `this` + Node* replaceWithNewSymbol(Symbol new_symbol); + + // Checks if this node is dominated by `dominator` which means that + // `dominator` will always be executed before `this` and `dominator` + // is in scope of `this. + bool isDominatedBy(const Node* dominator) const; + + // lots of things like chunk have a single input or single output, so we have + // a helper to make accessing it easier + Value* input() { + AT_ASSERT(inputs_.size() == 1); + return inputs_.at(0); + } + Value* output() { + AT_ASSERT(outputs_.size() == 1); + return outputs_.at(0); + } + const Value* output() const { + AT_ASSERT(outputs_.size() == 1); + return outputs_.at(0); + } + const Value* input() const { + AT_ASSERT(inputs_.size() == 1); + return inputs_.at(0); + } + // Access a particular input. This is a checked index. + Value* input(size_t i) const { + return inputs_.at(i); + } + + bool hasNamedInput(const std::string& unqualName) const; + Value* namedInput(const std::string& unqualName) const; + Value* namedInput(Symbol name) const; + + std::optional get(Symbol name) const; + + template + std::optional get(Symbol name) const { + if (auto v = get(name)) { + return v->template to(); + } + return std::nullopt; + } + + // Returns true if the value of input name is statically known + bool is_constant(Symbol name) const { + return static_cast(get(name)); + } + bool mustBeNone() const; + + bool isNondeterministic() const; + bool hasSideEffects() const; + + // instructions lowered by the interpreter and not run in the optimized graph + bool notExecutedOp() const { + return kind_ == prim::Constant || kind_ == prim::profile || + kind_ == prim::profile_ivalue; + } + + // Graphs + + // Note [Topological invariant] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // We always maintain an up-to-date topological ordering of all nodes via + // the next()/prev() links. All transformations to graphs must preserve + // this topological ordering: for example, it is only valid to 'addInput' + // with an input which is topologically before the current node. + // + // Usually, it is obvious whether or not topological order is maintained; + // for example, if you are adding nodes to the end of the topsort, it's + // impossible for them to refer to inputs that are not in the topsort. + // If it is not obvious, please comment accordingly. + + // Add 'node' as an input to 'this' at the end of existing + // arguments. Returns the added node for ease of chaining. + // + // Given: %3 = f(%1, %2) + // Execute: %3.addInput(%4) + // Result: %3 = f(%1, %2, %4) + Value* addInput(Value* value); + + // Add 'value' as an input to 'this' at the specified position in the + // arguments. Returns the added value for ease of chaining. + Value* insertInput(size_t i, Value* value); + + // Replace the input of 'this' at position 'i' with + // 'newValue', returning the old node. + // + // Given: %3 = f(%1, %2) + // Execute: %3.replaceInput(1, %4) + // Result: %3 = f(%1, %4) + Value* replaceInput(size_t i, Value* newValue); + + // Replace all occurrences of 'from' in the inputs of this + // node with 'to'. Corresponds to llvm's replaceUsesOfWith. + // + // Given: %3 = f(%1, %2, %1) + // Execute: %3.replaceInputWith(%1, %4) + // Result: %3 = f(%4, %2, %4) + void replaceInputWith(Value* from, Value* to); + + Value* addOutput(); + + Value* insertOutput(size_t i); + + void eraseOutput(size_t i); + + Block* addBlock(); + void eraseBlock(size_t i); + + // Each Node can have a list of subblocks. These are used to define structured + // nested control flow operators such as If and Loop. + // The meaning of a block is specific to the kind of node it is in, but + // all blocks share these semantics: + // * Nested lexical scoping: If a node 'Parent' has a subblock which contains + // a node 'Child', Child can use any value that was in scope for the Parent + // node in addition to any values defined before 'Child' in the subblock. + // * The list of inputs to the block are in scope for the duration of the + // block + // * the outputs of the Parent node are not in scope for the subblocks + // Typically the inputs to a block that represents control flow act as + // as the equivalents phi-nodes in standard SSA form, + // defining a new Value to represent any term that has multiple + // definitions depending on how control flowed. Outputs of the node containing + // control flow serve a similiar purpose defining new values for variables + // that would have different definitions depending on which way control + // flowed. + + at::ArrayRef blocks() { + return blocks_; + } + at::ArrayRef blocks() const { + // Vectors are not convertible in const-ness of elements, but + // raw pointers are. + return {blocks_.data(), blocks_.size()}; + } + + // Is 'this' before 'n' in the topological order? + bool isBefore(const Node* n) const; + + // Is 'this' after 'n' in the topological order? + bool isAfter(const Node* n) const; + + // Insert unattached 'this' node before 'n' in the topological order. + // Returns this (for chaining). + // + // Given: %3 = f(%1, %2) + // %4 = g(%3) + // and unattached: %5 = h(%1) + // Execute: %5.insertBefore(%4) + // Result: %3 = f(%1, %2) + // %5 = h(%1) + // %4 = g(%3) + Node* insertBefore(Node* n); + + // Insert unattached 'this' node after 'n' in the topological order. + // Returns this (for chaining). + // + // Given: %3 = f(%1, %2) + // %4 = g(%3) + // and unattached: %5 = h(%1) + // Execute: %5.insertAfter(%4) + // Result: %3 = f(%1, %2) + // %4 = g(%3) + // %5 = h(%1) + Node* insertAfter(Node* n); + + // Move 'this' (already in the graph) after 'n' in the topological order. + // + // NOTE: Does not check that value dependencies are preserved, see + // AliasDb::moveAfterTopologicallyValid + // + // Given: %2 = f(%1) + // %3 = g(%1) + // Execute: %2.moveAfter(%3) + // Result: %3 = g(%1) + // %2 = f(%1) + // + void moveAfter(Node* n); + + // Move a node 'n' (already in the graph) before 'this' in the topological + // order. + // + // NOTE: Does not check that value dependencies are preserved, see + // AliasDb::moveBeforeTopologicallyValid + // + // Given: %2 = f(%1) + // %3 = g(%1) + // Execute: %3.moveBefore(%2) + // Result: %3 = g(%1) + // %2 = f(%1) + void moveBefore(Node* n); + + // Remove the input at 'i' from this node. + // + // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling + // removeInput. + // + // Given: %3 = f(%1, %2) + // Execute: %3.removeInput(1) + // Result: %3 = f(%1) + void removeInput(size_t i); + + // Remove all inputs from a node. + // + // Given: %3 = f(%1, %2) + // Execute: %3.removeAllInputs() + // Result: %3 = f() + void removeAllInputs(); + + // Remove all outputs from a node. + // + // Given: %1, %2 = f() + // Execute:removeAllInputs() + // Result: = f() + void removeAllOutputs(); + + // Rearrange the ordering of inputs or outputs of a node + // Given: %3 = f(%1, %2) + // Execute: %3.permuteInputs({1, 0}) + // Result: %3 = f(%2, %1) + // Each index must appear exactly once + void permuteInputs(const std::vector& new_inputs); + void permuteOutputs(const std::vector& new_inputs); + + // iterators of the node list starting at this node + // useful for resuming a search starting at this node + inline graph_node_list_iterator iterator() { + return {this, 0}; + } + inline graph_node_list_iterator reverseIterator() { + return iterator().reverse(); + } + inline const_graph_node_list_iterator iterator() const { + return {this, 0}; + } + inline const_graph_node_list_iterator reverseIterator() const { + return iterator().reverse(); + } + + // Remove 'this' from the instruction list and deallocate it. + // + // Invariant: no outputs of 'this' may have any uses. + // + // Given: %2 = f(%1) + // %3 = g(%1) + // Execute: %2.destroy() + // Result: %3 = g(%1) + void destroy(); + + // Dynamically cast this node to the subclass indicated by the + // template variable, returning nullptr if the cast is invalid.. + // + // Example usage: if(auto s = n.cast to stride as arg VarHandle + std::unordered_map, VarHandle, SmallSizeTPairHash> + strideArgToVar_; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides_; + + // Memory layout to be propagated with fusion group + MemoryLayoutPolicy memory_layout_policy_ = MemoryLayoutPolicy::kContiguous; +}; + +TORCH_API int& getTECudaPointwiseLoopLevels(); +TORCH_API int& getTECudaPointwiseBlockCount(); +TORCH_API int& getTECudaPointwiseBlockSize(); +TORCH_API bool& getTEGenerateBlockCode(); +TORCH_API bool& getTEMustUseLLVMOnCPU(); +TORCH_API bool fallbackAllowed(); +TORCH_API bool setFallbackAllowed(bool value); +TORCH_API bool& getCatWoConditionals(); +TORCH_API bool& getOptConditionals(); + +TORCH_API std::optional pickDeviceType( + const at::ArrayRef& inputs); + +bool isContiguous( + const torch::jit::Value* v, + at::MemoryFormat memory_format = at::MemoryFormat::Contiguous); + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/llvm_codegen.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/llvm_codegen.h new file mode 100644 index 0000000000000000000000000000000000000000..1d96b4dd0467e3e807aa6a4282c745f2daa7f637 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/llvm_codegen.h @@ -0,0 +1,143 @@ +#pragma once + +#ifdef TORCH_ENABLE_LLVM +#include + +#include +#include +#include + +#include + +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +class LLVMCodeGenImpl; +class LLVMCodeGenCallee; + +class TORCH_API LLVMCodeGen : public CodeGen { + public: + explicit LLVMCodeGen( + StmtPtr stmt, + const std::vector& args, + at::Device device = at::kCPU, + const std::string& kernel_func_name = "func", + Dtype dtype = kInt, + std::optional triple = std::nullopt, + std::optional cpu = std::nullopt, + std::optional attrs = std::nullopt); + explicit LLVMCodeGen(StmtPtr stmt); + + LLVMCodeGen() = delete; + ~LLVMCodeGen() override; + + // Cleans up all the memory used during LLVM code generation pass except + // the generated kernel. After calling this method, users should not call + // methods like `getCodeText` that require the LLVMCodeGenImpl data. However, + // users can continue to call this kernel using `call` and `call_raw`. + void cleanup_memory(); + + TORCH_API void call(const std::vector& args) override; + TORCH_API void call_raw(const std::vector& args) override; + TORCH_API void call_with_numel(void** args, int64_t numel) override; + + at::Tensor empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) override; + + template + T value() { + return value(nullptr); + } + + template + T value(std::vector& args) { + return value(args.data()); + } + + template + T value(void** args) { + T (*fp)(void**) = (T(*)(void**))getKernelAddress(callee_.get()); + T rv = fp(args); + return rv; + } + + std::string getCodeText(const std::string& attr = "") override; + + private: + void* getKernelAddress(LLVMCodeGenCallee* callee); + + std::unique_ptr callee_; + std::unique_ptr impl_; +}; + +struct TORCH_API LLVMCodeGenBuilder { + using BufferArg = CodeGen::BufferArg; + + LLVMCodeGenBuilder(StmtPtr stmt, std::vector args) + : stmt_(stmt), args_(std::move(args)) {} + + LLVMCodeGenBuilder& device(at::Device device) { + device_ = device; + return *this; + } + + LLVMCodeGenBuilder& kernelFuncName(std::string name) { + kernelFuncName_ = std::move(name); + return *this; + } + + LLVMCodeGenBuilder& dtype(Dtype d) { + dtype_ = d; + return *this; + } + + LLVMCodeGenBuilder& triple(std::string triple) { + triple_ = std::move(triple); + return *this; + } + + LLVMCodeGenBuilder& cpu(std::string cpu) { + cpu_ = std::move(cpu); + return *this; + } + + LLVMCodeGenBuilder& attrs(std::string attrs) { + attrs_ = std::move(attrs); + return *this; + } + + std::unique_ptr build() { + return std::make_unique( + stmt_, args_, device_, kernelFuncName_, dtype_, triple_, cpu_, attrs_); + } + + private: + StmtPtr stmt_; + std::vector args_; + at::Device device_ = at::kCPU; + std::string kernelFuncName_ = "func"; + Dtype dtype_ = kInt; + std::optional triple_ = std::nullopt; + std::optional cpu_ = std::nullopt; + std::optional attrs_ = std::nullopt; +}; + +TORCH_API std::optional& LLVMTargetTriple(); +TORCH_API std::optional& LLVMTargetCPU(); +TORCH_API std::optional& LLVMTargetAttrs(); +TORCH_API bool& LLVMAOTWorkflow(); + +} // namespace tensorexpr +} // namespace jit +} // namespace torch + +#endif // TORCH_ENABLE_LLVM diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/llvm_jit.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/llvm_jit.h new file mode 100644 index 0000000000000000000000000000000000000000..beadbdd5e537e7f054a70200af16a2f7a67c21fa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/llvm_jit.h @@ -0,0 +1,77 @@ +#pragma once + +#ifdef TORCH_ENABLE_LLVM +#include +#include +#include +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override") +#include +C10_DIAGNOSTIC_POP() +#include +#include +#include + +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +inline std::string formatError(llvm::Error&& err, const char* msg) { + static constexpr const char* defaultErrorMsg = + "Unexpected failure in LLVM JIT"; + std::string errorMsg(msg ? msg : defaultErrorMsg); + llvm::raw_string_ostream ss(errorMsg); + ss << ": " << err; + return ss.str(); +} + +template +T assertSuccess(llvm::Expected valOrErr, const char* msg = nullptr) { + TORCH_INTERNAL_ASSERT(valOrErr, formatError(valOrErr.takeError(), msg)); + return std::move(*valOrErr); +} + +inline void assertSuccess(llvm::Error err, const char* msg = nullptr) { + TORCH_INTERNAL_ASSERT(!err, formatError(std::move(err), msg)); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch + +namespace llvm { +namespace orc { + +class PytorchLLVMJITImpl; + +class TORCH_API PytorchLLVMJIT { + public: + PytorchLLVMJIT( + std::optional triple, + std::optional cpu, + std::optional attrs); + ~PytorchLLVMJIT(); + + void addModule(std::unique_ptr M, std::unique_ptr C); + + JITSymbol findSymbol(const std::string Name); + + bool hasSymbol(const std::string& Name); + + TargetMachine& getTargetMachine(); + + const DataLayout& getDataLayout(); + + private: + // Use the PImpl idiom here to hide the no-rtti parts of the JIT structure. + std::unique_ptr impl_; +}; + +} // end namespace orc +} // end namespace llvm + +#endif // ENABLE LLVM diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/loopnest.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/loopnest.h new file mode 100644 index 0000000000000000000000000000000000000000..20614fea0bad9dafaf9190b1a931a8e24ed2b030 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/loopnest.h @@ -0,0 +1,616 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace torch::jit::tensorexpr { + +class Expr; +class Var; +class Buf; +class Tensor; +class Function; +class Stmt; +class For; +class Block; +class Store; +class Dtype; + +class TORCH_API LoopNest { + public: + // A constructor for building a LoopNest from a list of Tensors + LoopNest( + const std::vector& output_tensors, + const std::vector& tensors_to_compute); + + // A convenience constructor for the case when all tensors are output tensors + LoopNest(const std::vector& output_tensors); + + // A constructor for building a LoopNest from an Stmt and a list of output + // buffers. + LoopNest(StmtPtr stmt, std::unordered_set output_bufs); + + // A constructor for building a LoopNest from another loopnest. It clones the + // other loopnest's stmt. + LoopNest(const LoopNest& other); + + StmtPtr root_stmt() const { + return root_stmt_; + } + + std::vector getLoopStmtsFor(const Tensor&) const; + std::vector getLoopStmtsFor(const BufPtr&) const; + std::vector getLoopStmtsFor(StmtPtr) const; + StmtPtr getLoopBodyFor(const Tensor&) const; + StmtPtr getLoopBodyFor(BufPtr) const; + + // Returns the For stmt indexed by 'indices' in the 'root' For stmt. + //'indices' indicates the path to the returned loop from 'root' in AST, e.g., + // + // root: for(int i...){ + // j_loop: for (int j...){ + // k1_loop: for (int k1...){ + // A[i, j, k1] = .... + // } + // B[i, j] = ... + // k2_loop: for (int k2...){ + // A[i, j, k2] = ... + // } + // } + // } + // + // the path from 'root' to 'j_loop' is [0] + // the path from 'root' to 'k1_loop' is [0, 0] + // the path from 'root' to 'k2_loop' is [0, 2] + ForPtr getLoopAt(ForPtr root, const std::vector& indices) const; + + // Returns the For stmt that is immediately enclosing the given stmt. + static ForPtr getParentLoop(const StmtPtr& st); + + // Returns the list of For stmts corresponding to the loopnest that is + // enclosing the given stmt. + static std::vector getEnclosingLoopNest(const StmtPtr& st); + + // Returns a list of all Stmts that write to the given buf. + std::vector getAllWritesToBuf(BufPtr) const; + + // The following methods return the For loops that contain writes to + // the given buf. + // + // For example, consider the following code: + // for i1 + // for j1 + // a[i1,j1] = + // for i2 + // for j2 + // for k2 + // a[i2,j2] = + // for j3 + // a[i2,j3] = + + // Returns a list of For loops which directly contain a Stmt that writes + // to buf. + // For the above example: + // getAllInnermostLoopsWritingToBuf(a) => {j1, k2, j3} + std::vector getAllInnermostLoopsWritingToBuf(BufPtr) const; + + // Returns a list of For loopnests which contain a Stmt that writes to + // the given buf. Each loopnest here is a vector For loops. + // For the above example: + // getAllLoopNestsWritingToBuf(a) => {{i1,j1}, {i2,j2,k2}, {i2,j3}} + std::vector> getAllLoopNestsWritingToBuf(BufPtr) const; + + StmtPtr simplify(); + + // Sanitize variables and buffer names. + // The pass assigns predefined names for loop index variables + // (i,j,k,l,m,n,o,p,i1,j1,k1,...) and ensures these names are not conflicting + // anywhere. It also removes duplicates from other Buf nad Var names as well + // as replaces illegal characters in them with underscores. + // + // Note: since it's currently technically possible to use the same variable + // as index in two different loops, this transformation finds such cases and + // introduces new variables to avoid duplication. + static StmtPtr sanitizeNames(StmtPtr s); + + bool computeInline(const StmtPtr& s); + bool computeInline(const BufPtr& b); + void inlineIntermediateBufs(bool allow_duplicated_work); + + // Optimizes conditionals. + // + // Currently, only the following pattern of conditionals is optimized. + // This corresponds to the conditional format that is generated to handle + // `aten::cat` op. + // + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) + // } + // + // Constraints that must be satisfied for this optimization: + // * All conditions should be of the form "var < expr". + // * All conditions should have the same variable, say v. + // * The condition variable found should be the same as the inner-most + // loop variable. TODO: Remove this constraint. + // * If there are multiple stores that contain conditionals using the same + // loop variable, only the first conditional will be optimized. + // TODO: Remove this constraint. + bool optimizeConditionals(); + + // Splits the given loop into 2 nested loops with the given factor as the + // inner loop bound. If the factor does not evenly divide the loop bound, + // then the remaining iterations are extracted into a tail loop that is + // added after the given loop. + // + // For example, consider the following code: + // for (int i = 0; i < 100; ++i) { + // A[i] = + // } + // + // splitWithTail(i, 8, ...) will result in: + // for (int i_outer = 0; i_outer < 12; ++i_outer) { + // for (int i_inner = 0; i_inner < 8; ++i_inner) { + // A[i_outer * 8 + i_inner] = + // } + // } + // for (int i_tail = 0; i_tail < 4; ++i_tail) { + // A[i_tail + 96] = + // } + // + // The given loop will be transformed to the outer loop after splitting. + // So, the pointer to the input loop should be valid after splitting and + // will point to the outer loop. The `inner` and `tail` parameters will be + // set to point to the inner and tail loops that are generated. + static void splitWithTail( + const ForPtr& f, + int factor, + ForPtr* inner, + ForPtr* tail); + // A convenience wrapper when the caller does not need to access the + // split loops. + static void splitWithTail(const ForPtr& f, int factor); + + // Splits the given loop into 2 nested loops with the given factor as the + // inner loop bound. If the factor does not evenly divide the loop bound, + // then a conditional is inserted into the body to handle the remaining + // iterations appropriately. + // + // For example, consider the following code: + // for (int i = 0; i < 100; ++i) { + // A[i] = + // } + // + // splitWithMask(i, 8, ...) will result in: + // for (int i_outer = 0; i_outer < 13; ++i_outer) { + // for (int i_inner = 0; i_inner < 8; ++i_inner) { + // if (i_outer * 8 + i_inner < 100) { + // A[i_outer * 8 + i_inner] = + // } + // } + // } + // + // The given loop will be transformed to the outer loop after splitting. + // So, the pointer to the input loop should be valid after splitting and + // will point to the outer loop. The `inner` parameter will be set to point + // to the inner loop that is generated. + static void splitWithMask(const ForPtr& f, int factor, ForPtr* inner); + // A convenience wrapper when the caller does not need to access the + // split loops. + static void splitWithMask(const ForPtr& f, int factor); + + // The following methods support loop distribution. + // For example, consider the following code. This will be used to + // demonstrate the methods below. + // + // S0: for m + // S1: for i + // S2: A[i] = 0 + // S3: for j + // S4: A[i] = A[i] + + // S5: B[i] = A[i] + // S6: for k + // S7: B[i] = B[i] + + + // This method distributes the given loop over its body by splitting + // after every given pivot stmt. + // + // NOTE: Pivot stmts that are not in the given loop's body will be ignored. + // + // For the above example: + // distributeLoop(S1, {S3, S5}) + // will result in: + // S0: for m + // S1: for i + // S2: A[i] = 0 + // S3: for j + // S4: A[i] = A[i] + + // : for i + // S5: B[i] = A[i] + // : for i + // S6: for k + // S7: B[i] = B[i] + + static std::vector distributeLoop( + const ForPtr& loop, + const std::unordered_set& pivots); + + // This method distributes the given loop over every stmt in its body. + // + // For the above example: + // distributeLoop(S1) + // will result in: + // S0: for m + // S1: for i + // S2: A[i] = 0 + // : for i + // S3: for j + // S4: A[i] = A[i] + + // : for i + // S5: B[i] = A[i] + // : for i + // S6: for k + // S7: B[i] = B[i] + + static std::vector distributeLoop(const ForPtr& loop); + // Same as above, but also distribute parent loops. + // Returns the result of distributing the outermost loop. + // + // For the above example: + // distributeLoopAndParents(S1) will result in: + // S0: for m + // S1: for i + // S2: A[i] = 0 + // : for m + // : for i + // S3: for j + // S4: A[i] = A[i] + + // : for m + // : for i + // S5: B[i] = A[i] + // : for m + // : for i + // S6: for k + // S7: B[i] = B[i] + + static std::vector distributeLoopAndParents(const ForPtr& loop); + + // This method distributes the given loop over its body by splitting + // after every For stmt in its body. + // + // For the above example: + // distributeLoopOverInnerLoops(S1) + // will result in: + // S0: for m + // S1: for i + // S2: A[i] = 0 + // S3: for j + // S4: A[i] = A[i] + + // : for i + // S5: B[i] = A[i] + // S6: for k + // S7: B[i] = B[i] + + static std::vector distributeLoopOverInnerLoops(const ForPtr& loop); + // Same as above, but also distribute parent loops. + // Returns the result of distributing the outermost loop. + // + // For the above example: + // distributeLoopAndParentsOverInnerLoops(S1) + // will result in: + // S0: for m + // S1: for i + // S2: A[i] = 0 + // S3: for j + // S4: A[i] = A[i] + + // : for m + // : for i + // S5: B[i] = A[i] + // S6: for k + // S7: B[i] = B[i] + + static std::vector distributeLoopAndParentsOverInnerLoops( + const ForPtr& loop); + + // This method performs loop fusion. + // For example, consider the following code. + // + // S1: for m + // S2: A[m] = 0 + // S3: for j + // S4: A[m] = A[m] + + // S5: for n + // S5: B[n] = A[n] + // S6: for k + // S7: B[n] = B[n] + + // + // fuseLoops({S1, S5}), will return the following loop: + // S1: for m + // S2: A[m] = 0 + // S3: for j + // S4: A[m] = A[m] + + // S5: B[m] = A[m] + // S6: for k + // S7: B[m] = B[m] + + // + // This transformation is unsafe as it simply add all loops into the body of + // the first loop for fusion without correctness checks. + // + // Below are the two requirements to apply unsafeFuseLoops: + // * All the loops have the same parent. + // * There are no statements between these loops in their parent body. + static bool unsafeFuseLoops(const std::vector& loops, ForPtr* fused); + + // Loop fusion is done only when all the conditions below are satisfied. + // * All the loops have the same parent. + // * There are no statements between these loops in their parent body. + // * The start bounds are the same for all loops. + // * The stop bounds are the same for all loops. + // * Fusing the loops does not violate or add any dependencies. + static bool fuseLoops(const std::vector& loops, ForPtr* fused); + + static void reorderAxis(const ForPtr& a, const ForPtr& b); + + // Reorder the given list of loops according to the permutation specified. + // Here `permutation[i]` represents the position of the loop in the input + // which will end up at position `i` after the reorder. + // + // For example, consider the following code: + // for p + // for q + // for r + // for s + // A[p,q,r,s] = + // + // reorder({p, q, r, s}, {2, 3, 0, 1}) will return the list of loops in the + // following form: + // for r + // for s + // for p + // for q + // A[p,q,r,s] = + static std::vector reorder( + const std::vector& loops, + const std::vector& permutation); + + // Tile takes a 2d domain (x, y) and splits it into small rectangular blocks + // each with shape (x_factor, y_factor). The traversal over the domain turns + // into an outer iteration over the blocks and an inner traversal over all + // points in the block. + // Note that if x dim % x_factor or y dim % y_factor does not equal to 0, the + // loop body will generate corresponding tailing loops. + // The transformation is in-place and returns 'xtail'. + // + // For example, consider the following code: + // for i: [0, 64) + // for j: [0, 64) + // for k: [0, 32) + // A[i, j] = B[i, k] + C[j, k] + // + // tile(i, j, 4, 8) will transform "i" for-stmt into the following nested + // loop: + // for i_outer: [0, 16) + // for j_outer: [0, 8) + // for i_inner: [0, 4) + // for j_inner: [0, 8) + // for k: [0, 32) + // A[i_outer * 4 + i_inner, j_outer * 8 + j_inner] = + // B[i_outer * 4 + i_inner, k] + C[j_outer * 8 + j_inner, k] + // + // tile(i, j, 4, 9) will transform "i" for-stmt into the following nested + // loop: + // for i_outer: [0, 16) + // for j_outer: [0, 7) + // for i_inner: [0, 4) + // for j_inner: [0, 9) + // for k: (0, 32) + // A[i_outer * 4 + i_inner, j_outer * 9 + j_inner] = + // B[i_outer * 4 + i_inner, k] + C[j_outer * 9 + j_inner, k] + // for j_tail: [0, 1) + // for i_inner: [0, 4) + // for k: (0, 32) + // A[i_outer * 4 + i_inner, 7 * 9 + j_tail] = + // B[i_outer * 4 + i_inner, k] + C[7 * 9 + j_tail, k] + ForPtr tile(const ForPtr& x, const ForPtr& y, int x_factor, int y_factor); + + // Returns true if the given loops are perfectly nested, i.e., every loop + // (except the innermost) should have exactly one statement in its body + // and that statement must be the next inner loop. + static bool areLoopsPerfectlyNested(const std::vector& loops); + + // Returns true if the given loop has a loop-carried dependence. + static bool hasLoopCarriedDependence(const ForPtr& loop); + + // Unrolls all the iterations of the given loop. + // Requires that the loop bounds are constant. + static void fullUnroll(const ForPtr& f, StmtPtr* unrolled); + static void fullUnroll(const ForPtr& f); + + // Unrolls the given loop for the specified factor. + // This does not require constant bounds for the loop being unrolled. + static void unroll(const ForPtr& f, int factor, ForPtr* tail); + static void unroll(const ForPtr& f, int factor); + + static bool normalize(const ForPtr& f); + static bool isNormalized(const ForPtr& f); + + static bool flatten(const std::vector& f, ForPtr* flattened); + static bool flatten(const std::vector& f); + + // Compresses the given buffer based on its use in the given Stmts. + // + // NOTE: This API assumes that there are no accesses to the given buffer + // outside the given statement. So, this should be called with the entire + // kernel statement to avoid incorrect buffer compressions. + // + // For example, given the input: + // + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[i,j] = sin(i*j) + // } + // for (int j = 0; j < 199; ++j) { + // B[i,j] = A[i,j] + A[i, j+1] + // } + // } + // + // compressBuffer(A, ...) will compress buffer A from + // [100, 200] to [1, 200] and modify the code as follows: + // + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[0,j] = sin(i*j) + // } + // for (int j = 0; j < 199; ++j) { + // B[i,j] = A[0,j] + A[0, j+1] + // } + // } + static void compressBuffer(const BufPtr& buf, const StmtPtr& stmt); + + // Compresses all buffers in the given statement. + // + // NOTE: This API assumes that there are no accesses to buffers outside + // the given statement. So, this should be called with the entire + // kernel statement to avoid incorrect buffer compressions. + // + // TODO: Add an IR verifier check to detect invalidly compressed buffers. + static void compressAllBuffers(const StmtPtr& stmt); + + // Get 'num' loops from the loopnest starting at 'f'. + static std::vector getLoopStmtsInLoopNest( + const ForPtr& f, + size_t num); + + // LoopOptions are propagated to tail. + static void sliceHead( + const ForPtr& f, + int factor, + ForPtr* head, + ForPtr* tail); + static void sliceHead(const ForPtr& f, int factor); + // LoopOptions are propagated to head. + static void sliceTail( + const ForPtr& f, + int factor, + ForPtr* head, + ForPtr* tail); + static void sliceTail(const ForPtr& f, int factor); + + using AccessResult = std::pair; + // Insert a cache for the consumer's usages of the buffer produced in + // consumer, and redirect reads and writes in the consumer to that cache. + // Returns a pair of the new cache buffer, and the new rewritten consumer. + static AccessResult cacheAccesses( + const BufPtr& producer, + const std::string& name, + const StmtPtr& consumer); + + // Insert a temporary computation of statement S in the scope of loop AT. + // S is assumed to be a Store or a Block containing a Store. Along with the + // computation itself, this transformation inserts Alloc/Free statements for + // the temporary buffer used in the computation. + static void computeAt(const StmtPtr& s, const ForPtr& at); + + // Rfactor a reduction axis into a normal axis. + // + // Requirements: + // * S is the reduction store + // * S is the only statement in the innermost loop + // * There is at least two reduction arguments in S + // * OUTER_REDUCTION_FOR loop corresponds to the outermost reduction variable + // used in the store and all other reduction variables are index variables of + // children loops of OUTER_REDUCTION_FOR + // * OUTER_REDUCTION_FOR is a perfect loop nest, i.e. it has only loops + // corresponding to the other reduction variables and the store, nested into + // each other + // + // What it does: + // * Introduce a new buffer with an extra dimension of a size equal to the + // span of the loop OUTER_REDUCTION_FOR (the new buffer is returned via + // RFAC_BUF_PTR) + // * Insert an initialization store for the new buffer in + // OUTER_REDUCTION_FOR before its nested loop + // * Replace the reduction store to the original buffer with the reduction + // store to the temp buffer, removing the index var of OUTER_REDUCTION_FOR + // from reduction arguments + // * Insert a final reduction store over the extra dimension of the new + // buffer to the original buffer + // * Returns TRUE if the transformation succeeded and FALSE otherwise + // + // Example: + // Original IR: + // S1: for i # normal axis + // S2: X[i] = 0 + // S3: for j # reduction axis + // S4: for k # reduction axis + // S5: X[i] = ReduceOp(X[i] + Y[i,j,k], reduce_axis={j,k}) + // + // After RFACTOR(S5, S3) + // S1: for i # normal axis + // S2: X[i] = 0 + // S3: for j # reduction axis for X, normal axis for X_rfac + // X_rfac[i,j] = 0 + // S4: for k # reduction axis + // X_rfac[i,j] = ReduceOp(X_rfac[i,j] + Y[i,j,k], reduce_axis={k}) + // X[i] = ReduceOp(X[i] + X_rfac[i,j], reduce_axis={j}) + static bool rfactor(const StmtPtr& s, const ForPtr& outer_reduction_for); + static bool rfactor( + const StmtPtr& s, + const ForPtr& outer_reduction_for, + BufPtr* rfac_buf_ptr); + + // Vectorize the given loop. This method requires that the given loop + // does not perform a reduction. + // It returns true if vectorization is successful and false otherwise. + static bool vectorize(const ForPtr&); + + // Find the inner-most loops and vectorize them. Currently, this only works + // for the LLVM backend, when no reductions are involved. + void vectorizeInnerLoops(); + + void eliminateDeadStores(); + + void prepareForCodegen(); + + const std::unordered_set getInputBufs() const; + const std::unordered_set getOutputBufs() const { + return output_bufs_; + } + std::vector getIntermediateBufs() const; + + // Finds which is the outer For between a and b for loops. If neither of the 2 + // Fors is an ancestor of the other, it returns nullptr. + static ForPtr findOuterFor(ForPtr a, ForPtr b); + + private: + void initialize( + const std::vector& output_tensors, + const std::vector& tensors_to_compute); + + StmtPtr root_stmt_; + + std::unordered_set output_bufs_; +}; + +TORCH_API StmtPtr FlattenIndexes(const StmtPtr& s); + +// TODO: Revisit this once we decide on how dependencies analysis should look +// like. Maybe we would choose to use a different API and BufUse would be +// removed, or if we decide to keep it we need to properly document its API. +struct BufLoadOrStoreUse { + StmtPtr s; + bool isStore; +}; + +/* + * Returns a map ( Buf -> uses of this Buf), uses are represented as vectors of + * BufUse elements, which are StmtPtr and a bool isStore flag. The order of uses + * in the vectors reflects the order in which the uses appear in the given + * statement. + */ +std::unordered_map> findLoadOrStoreUses( + const StmtPtr& s); + +// replaces all invalid characters with underscore +TORCH_API std::string sanitizeName(const std::string& input_name); + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/loopnest_randomization.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/loopnest_randomization.h new file mode 100644 index 0000000000000000000000000000000000000000..3f5f687416cb752aa32319680d49c928e1549fe4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/loopnest_randomization.h @@ -0,0 +1,9 @@ +#pragma once + +namespace torch::jit::tensorexpr { + +// Applies a series of loop optimizations chosen randomly. This is only for +// testing purposes. This allows automatic stress testing of NNC loop +// transformations. +void loopnestRandomization(int64_t seed, LoopNest& l); +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/lowerings.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/lowerings.h new file mode 100644 index 0000000000000000000000000000000000000000..2f53a2c791d4d78247eb3bf03a7a11a6f68ab97e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/lowerings.h @@ -0,0 +1,45 @@ +// This file defines classes for registering standard lowerings from JIT to TE +// IR. +#pragma once + +#include +#include +#include +#include +#include + +namespace torch::jit::tensorexpr { + +using ArgNone = std::monostate; +using BufList = std::vector; +using DoubleList = std::vector; +using IntList = std::vector; +using ArgValue = std::variant< + tensorexpr::BufHandle, + tensorexpr::VarHandle, + double, + int64_t, + bool, + BufList, + DoubleList, + IntList, + std::string, + ArgNone>; + +using NNCLoweringFunction = std::function&, + const std::vector&, + const std::vector&, + const std::optional&, + at::Device)>; + +TORCH_API FunctionSchemaMap& getNNCLoweringRegistry(); +TORCH_API NNCLoweringFunction getStandardLoweringFor(const std::string& op); + +struct RegisterNNCLoweringsFunction { + RegisterNNCLoweringsFunction( + const std::vector& schemas, + const NNCLoweringFunction& fn); +}; + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/mem_dependency_checker.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/mem_dependency_checker.h new file mode 100644 index 0000000000000000000000000000000000000000..39202f487ad2de57b5e566bcd1e5cb9ac0fdd50e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/mem_dependency_checker.h @@ -0,0 +1,409 @@ +#pragma once +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace torch::jit::tensorexpr::analysis { + +enum class AccessType { + Input, + Output, + Load, + Store, + Call, + AtomicAdd, + Alloc, + Free +}; +const char* AccessToString(AccessType a); + +class AccessInfo; +using DependencySet = std::unordered_set>; + +/* AccessInfo + * + * Represents a single bounded memory access to a buffer, for instance a Load or + * a Store. Holds information relating to the specific access and links to + * connected accesses in the dependency graph. + */ +class TORCH_API AccessInfo { + public: + AccessInfo( + size_t id, + AccessType type, + StmtPtr stmt, + VarPtr var, + IndexBounds bounds) + : id_(id), + type_(type), + stmt_(std::move(stmt)), + expr_(nullptr), + var_(std::move(var)), + bounds_(std::move(bounds)) {} + + AccessInfo( + size_t id, + AccessType type, + ExprPtr expr, + StmtPtr stmt, + VarPtr var, + IndexBounds bounds) + : id_(id), + type_(type), + stmt_(std::move(stmt)), + expr_(std::move(expr)), + var_(std::move(var)), + bounds_(std::move(bounds)) {} + + // Id is a unique int representing the order this access occurred in the + // graph. + size_t id() const { + return id_; + } + + // The type of the access (Load, Store, etc). + AccessType type() const { + return type_; + } + + // The enclosing Stmt this access represents. E.g. if this is a Store then + // Stmt is the Store itself, while if the access is caused by an Expr, this is + // the most immediate parent Stmt. + StmtPtr stmt() const { + return stmt_; + } + + // If the access is represented by an Expr (such as Load or Call) then this is + // it, otherwise it's nullptr. + ExprPtr expr() const { + return expr_; + } + + // The Var representing the underlying Buffer. + VarPtr var() const { + return var_; + } + + // A vector of Bounds representing the start and end expression for each + // dimension. + IndexBounds& bounds() { + return bounds_; + } + + // Each access that this depends upon, + // eg. if this is a Load, then it contains every Store that immediately + // contributes to a load of the bounds. + // or: if this is a Store, it contains all reads on the RHS of the Store. + const std::map>& dependencies() const { + return dependencies_; + } + + // Each access that depends on this one. + // ie. this access is present in the dependencies map of all accesses that are + // dependent. + std::map> dependents() const { + std::map> res; + for (const auto& kv : dependents_) { + res.emplace(kv.first, kv.second.lock()); + } + return res; + } + + // Returns the symbolic expression of the indices of this access. + std::vector getIndices() const; + + // Establishes a dependency or dependent relationship with another access. + void addDependency(const std::shared_ptr& write); + void addDependent(const std::shared_ptr& read); + + // helper for checking dependencies. + bool hasDependency(const std::shared_ptr& info) const; + + // Returns the set of all nodes that are direct (immediate) dependencies of + // this access. + DependencySet getDirectDependencies(); + // likewise, returns all nodes that directly depend on this one. + DependencySet getDirectDependents(); + + // Returns the full list of all nodes in the graph that this access depends + // on, and all nodes they depend on, and so forth, back to the inputs. + DependencySet getIndirectDependencies(); + // likewise, returns the full list of all nodes that depend on this node, and + // all nodes that depend on those nodes and so on down to the outputs. + DependencySet getIndirectDependents(); + + // Does this access represent a read of memory (Load, ReduceOp, Call, etc). + bool isRead() const; + // Does this access represent a write of memory (Store, etc). + bool isWrite() const; + + // Helpers for dumping accesses in various formats. + void print() const; + void dumpDOT(std::ostream& os) const; + const char* AccessTypeColour() const; + + private: + size_t id_; + AccessType type_; + StmtPtr stmt_; + ExprPtr expr_; + VarPtr var_; + IndexBounds bounds_; + + // Yes these should be sorted. + std::map> dependencies_; + std::map> dependents_; +}; + +using VarBoundMap = std::unordered_map; + +/* MemDependencyChecker analyses a IR fragment and builds a dependency graph of + * accesses contained within. + * + * It's possible to retrieve the entire graph in node-object form, or can be + * used as an oracle for answering dependency questions. e.g: + * + * analyzer.hasIndirectDependency(BufA, BufB); or, + * analyzer.hasDirectDependency(LoadA, StoreB); + */ +class TORCH_API MemDependencyChecker : public IRVisitor { + struct Scope; + + public: + MemDependencyChecker(); + MemDependencyChecker( + const std::unordered_set& inputs, + const std::unordered_set& outputs); + MemDependencyChecker( + const std::vector& inputs, + const std::vector& outputs); + + ~MemDependencyChecker() override = default; + + // Whether or not to allow loop execution order to influence dependency + // calculation. If the loop may later be parallelized you don't want this. + bool allowLoopExecutionOrderAnalysis(bool allow = true); + + // Dependency Checking API. + // The goal is to have enough overloads here so you don't really have to think + // about it. + + // Returns true if any read in A has a direct dependence on a write in B. + bool dependsDirectly(const StmtPtr& A, const StmtPtr& B); + bool dependsDirectly(const ExprPtr& A, const StmtPtr& B); + + // Returns true of the output depends directly on a write contained in B. + bool dependsDirectly(const BufPtr& output, const StmtPtr& B); + + // Returns true if a read in A depends directly on the provided input. + bool dependsDirectly(const StmtPtr& A, const BufPtr& input); + bool dependsDirectly(const ExprPtr& A, const BufPtr& input); + + // Outputs/inputs cannot depend directly. + + // Returns true if the access A has B as an immediate dependency. + bool dependsDirectly( + const std::shared_ptr& A, + const std::shared_ptr& B); + + // Returns true if any read in A has an ancestor write contained in B. + bool dependsIndirectly(const StmtPtr& A, const StmtPtr& B); + bool dependsIndirectly(const ExprPtr& A, const StmtPtr& B); + + // Returns true of the output depends indirectly on a write contained in B. + bool dependsIndirectly(const BufPtr& output, const StmtPtr& B); + + // Returns true if a read in A depends indirectly on the provided input. + bool dependsIndirectly(const StmtPtr& A, const BufPtr& input); + bool dependsIndirectly(const ExprPtr& A, const BufPtr& input); + + // returns true if the output uses any load of the input. + bool dependsIndirectly(const BufPtr& output, const BufPtr& input); + + // Returns true if the access A has a dependency chain to access B. + bool dependsIndirectly( + const std::shared_ptr& A, + const std::shared_ptr& B); + + // Returns the AccessInfo + std::shared_ptr accessFor(const StmtPtr& A) const; + std::shared_ptr accessFor(const ExprPtr& A) const; + + // Returns all AccessInfos. + std::unordered_set> accessesWithin( + const StmtPtr& A) const; + // TODO: this will return only the AccessInfo for A. It's included for + // completeness but be aware it wont return accesses used in the computation + // of A. + std::unordered_set> accessesWithin( + const ExprPtr& A) const; + + // Accesses relating to input and output buffers. + std::shared_ptr input(const BufPtr& B) const; + std::shared_ptr output(const BufPtr& B) const; + + // Returns the full history of reads and writes. + const std::vector>& getHistory() const; + + // Dumps the dependency graph in DOT format. + void dumpDAG(const std::string& filename) const; + + private: + // Node visitors. + void visit(const StorePtr& v) override; + void visit(const LoadPtr& v) override; + void visit(const ForPtr& v) override; + void visit(const CondPtr& v) override; + void visit(const IfThenElsePtr& v) override; + void visit(const CompareSelectPtr& v) override; + void visit(const BlockPtr& v) override; + void visit(const LetPtr& v) override; + void visit(const AtomicAddPtr& v) override; + void visit(const AllocatePtr& v) override; + void visit(const FreePtr& v) override; + + using BoundRelationship = std::pair>; + + // An internal struct holding the accesses found within a scope Block. + struct Scope { + Scope(BlockPtr b, std::shared_ptr p) + : block(std::move(b)), parent(std::move(p)) {} + + BlockPtr block; + std::shared_ptr parent; + + std::unordered_map shadowedVarBounds; + std::unordered_set localVars; + + std::vector> accesses_; + + std::unordered_map> openWrites_; + }; + std::shared_ptr currentScope_; + + bool allowExecutionOrderAnalysis_{false}; + + std::unordered_multimap> stmtToAccess_; + std::unordered_multimap> exprToAccess_; + std::unordered_map>> + scopeToAccesses_; + + VarBoundMap knownVarBounds_; + + // Finds all accesses that are reads within the scope of v. + template + DependencySet getAllReadsWithin(const StmtOrExprPtr& v) { + DependencySet reads; + auto insertAllReads = [&](const auto& nodes) { + for (const auto& l : nodes) { + auto bound = exprToAccess_.equal_range(l); + for (auto it = bound.first; it != bound.second; ++it) { + if (it->second->isRead()) { + reads.insert(it->second); + } + } + } + }; + + // Look for and insert accesses belonging to all nodes that act like + // reads. + insertAllReads(NodeFinder::find(v)); + insertAllReads(NodeFinder::find(v)); + + return reads; + } + + // Finds all accesses that are writes within the scope of v. + // Writes cannot occur in Exprs, so this is a little simpler. + DependencySet getAllWritesWithin(const StmtPtr& v) { + DependencySet writes; + + // writes just Store currently. + auto stores = NodeFinder::find(v); + for (const auto& s : stores) { + auto bound = stmtToAccess_.equal_range(s); + for (auto it = bound.first; it != bound.second; ++it) { + if (it->second->isWrite()) { + writes.insert(it->second); + } + } + } + return writes; + } + + // Templated helpers to work on either Exprs or Stmts. + template + bool dependsDirectlyHelper(const StmtOrExprPtr& A, const StmtPtr& B) { + auto aReads = getAllReadsWithin(A); + auto bWrites = getAllWritesWithin(B); + + for (auto& read : aReads) { + for (auto& depPair : read->dependencies()) { + if (bWrites.count(depPair.second) != 0) { + return true; + } + } + } + + return false; + } + + template + bool dependsIndirectlyHelper(StmtOrExprPtr A, const StmtPtr& B) { + auto aReads = getAllReadsWithin(A); + auto bWrites = getAllWritesWithin(B); + + auto aDeps = getAllWriteDependencies(aReads); + + for (auto& dependency : aDeps) { + if (bWrites.count(dependency) != 0) { + return true; + } + } + + return false; + } + + DependencySet getAllWriteDependencies(const DependencySet& products); + + // Maps for inputs and outputs, since they aren't present directly in the IR. + std::unordered_map> inputs_; + std::unordered_map> outputs_; + std::unordered_map> intermediates_; + + // Inserts accesses for Buf's: specifically for inputs and outputs. + void insertBuffers( + std::unordered_map>& bufs, + AccessType type); + + // Update the write history with a new write, adding dependencies and closing + // any overlapped writes (if possible). + void updateWriteHistory( + std::list& writeHistory, + const std::shared_ptr& info, + size_t latestAccessToClose, + bool closeOverlapped = true, + bool insert = true); + + // Merge a child scope into a parent scope, adding dependencies for open + // writes in the parent to accesses in the child. + void mergeScope( + const std::shared_ptr& child, + const std::shared_ptr& parent, + bool closeOverlapped = true); + + // Binds symbolic vars in indices with the low and high bound for those vars. + std::vector getIndicesBounds(const std::vector& indices); + + size_t nextAccess_{0}; + StmtPtr lastStmt_{nullptr}; +}; + +} // namespace torch::jit::tensorexpr::analysis diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/conv2d.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/conv2d.h new file mode 100644 index 0000000000000000000000000000000000000000..9aa328d98b6db8cc567fcb84c32a5d768e8befb0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/conv2d.h @@ -0,0 +1,101 @@ +#pragma once + +#include +#include + +namespace torch::jit::tensorexpr { + +// An API to compute 2D depthwise convolutions with bias. +TORCH_API Tensor conv2d_depthwise( + BufHandle input, + BufHandle weight, + BufHandle bias, + int stride, + int pad, + int groups); + +// An API to compute 2D depthwise convolutions without bias. +TORCH_API Tensor conv2d_depthwise( + BufHandle input, + BufHandle weight, + int stride, + int pad, + int groups); + +TORCH_API Tensor conv2d_depthwise( + BufHandle input, + BufHandle weight, + BufHandle bias, + ExprHandle N, + ExprHandle C, + ExprHandle H, + ExprHandle W, + ExprHandle K, + ExprHandle CperG, + ExprHandle R, + ExprHandle S, + ExprHandle stride, + ExprHandle pad, + ExprHandle groups); + +TORCH_API Tensor conv2d_depthwise( + BufHandle input, + BufHandle weight, + ExprHandle N, + ExprHandle C, + ExprHandle H, + ExprHandle W, + ExprHandle K, + ExprHandle CperG, + ExprHandle R, + ExprHandle S, + ExprHandle stride, + ExprHandle pad, + ExprHandle groups); + +bool conv2dIsSupported( + const TensorInfo& input, + const TensorInfo& weight, + const TensorInfo& bias, + const std::vector& stride, + const std::vector& pad, + const std::vector& dilation, + int64_t groups); +bool mkldnnPrepackedConvIsSupported( + const TensorInfo& input, + const TensorInfo& weight, + const std::vector& stride, + const std::vector& pad, + const std::vector& dilation, + int64_t groups); +Tensor computeConv2d( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +Tensor computeConv1d( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +Tensor computePrepackedConv2dClampRun( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +Tensor computePrepackedLinearClampRun( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +Tensor computeMkldnnPrepackedConvRun( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/matmul.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/matmul.h new file mode 100644 index 0000000000000000000000000000000000000000..d572a1c396c0e3637bf981983f8e27d36474a932 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/matmul.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace torch::jit::tensorexpr { + +Tensor computeMatmul( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +Tensor computeAddMM( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/misc.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/misc.h new file mode 100644 index 0000000000000000000000000000000000000000..cb257eb3b7e03e25ddc96d4ea749c848ac4666e7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/misc.h @@ -0,0 +1,94 @@ +#pragma once + +#include +#include +#include + +namespace torch::jit::tensorexpr { + +struct TensorInfo { + std::vector dims; + c10::ScalarType dtype; +}; +std::optional getTensorInfo(const BufHandle& b); + +int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size); + +// Convert boolean to integer, if needed. +ExprHandle boolToInteger(const ExprHandle& x); +ExprHandle promoteToDtype(ExprHandle e, ScalarType dt); +void promoteInputs( + std::vector& inputs, + const int typeConstraints = kAllTypes); +ExprHandle promoteIntegerToDefaultType(const ExprHandle& e); +ExprHandle promoteHalfToFloat(const ExprHandle& e); +ExprHandle demoteOutput( + const ExprHandle& e, + const std::optional type); + +std::vector broadcastShapes( + std::vector> shapes); +std::vector broadcastShapes( + const std::vector& a, + const std::vector& b); + +std::vector valueShape(const ArgValue& v); +ExprHandle tensorOrConstant( + const ArgValue& v, + const std::vector& axes); +ExprHandle scalarOrConstant(const ArgValue& v); +ExprHandle broadcast(const BufHandle& b, const std::vector& axes); +ExprHandle constant(const ArgValue& v); + +ExprHandle clamp( + const ExprHandle& cmin, + const ExprHandle& cmax, + const ExprHandle& input); + +Tensor computeChunk( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +Tensor computeTranspose( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +Tensor computeExpand( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +Tensor computeReshape( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +Tensor computeFlatten( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +Tensor computeCatWoConditionals( + const std::vector& inputs, + const std::vector& outputShape); +Tensor computeCat( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +Tensor computeEmbedding( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/norm.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/norm.h new file mode 100644 index 0000000000000000000000000000000000000000..e531943237b098582e2303cf6e1f734118826b77 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/norm.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace torch::jit::tensorexpr { + +Tensor computeBatchNorm( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/operators.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/operators.h new file mode 100644 index 0000000000000000000000000000000000000000..6298a6480149b9db1536ea408094e1d259c2605f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/operators.h @@ -0,0 +1,10 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/pointwise.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/pointwise.h new file mode 100644 index 0000000000000000000000000000000000000000..8f8f6240d19848dd28f73d91aff00b2b6db4e8b4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/pointwise.h @@ -0,0 +1,82 @@ +#pragma once + +#include + +namespace torch::jit::tensorexpr { + +TORCH_API Tensor computeSign( + const std::vector& inputs, + const std::vector& outputShape, + const std::optional>& outputStrides = std::nullopt); + +Tensor computeOneOperand( + const std::string& name, + const std::vector& inputValues, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + const std::function& innerExpr, + const int checkParamTypes = kAllTypes); +Tensor computeTwoOperand( + const std::string& name, + const std::vector& inputValues, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + const std::function& + innerExpr); +Tensor computeTwoOperandWithAlpha( + const std::string& name, + const std::vector& inputValues, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + const std::function& + innerExpr); +Tensor computeConditionWithTwoOperand( + const std::string& name, + const std::vector& inputValues, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + const std::function< + ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>& + innerExpr); +Tensor computeThreeOperand( + const std::string& name, + const std::vector& inputValues, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + const std::function< + ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>& + innerExpr, + bool promote_inputs = true); +Tensor computeFourOperand( + const std::string& name, + const std::vector& inputValues, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + const std::function& innerExpr); +Tensor computeNoop( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +Tensor computeScalar( + const std::string& name, + const std::vector& inputValues, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + const std::function& + innerExpr); + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/quantization.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/quantization.h new file mode 100644 index 0000000000000000000000000000000000000000..51bdbe730a6a0b4f55954f959e8f1be2a82d1dc2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/quantization.h @@ -0,0 +1,156 @@ +#pragma once + +#include + +namespace torch::jit::tensorexpr { + +TORCH_API ExprHandle quantizePerTensorQParamFromArg(ArgValue arg); + +TORCH_API double immQScale(const BufHandle& qx); + +TORCH_API int64_t immQZero(const BufHandle& qx); + +TORCH_API ScalarType immQDType(const BufHandle& qx); + +TORCH_API bool isQuantized(const BufHandle& qx); + +TORCH_API Tensor computeQuantizePerTensor( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizePerTensorExternalCall( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizedConv1d( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizedConv2dPrepack( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizedConv1d( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizedConv2d( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizedConv2dRelu( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizedLinear( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizedLinearRelu( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizedAdd( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +Tensor computeQuantizedAddExternalCall( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizedMul( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizedMulScalar( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizedCat( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizedRelu( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeDequantize( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeDequantizeExternalCall( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeUpsampleNearest2d( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeUpsampleNearest2dExternalCall( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +TORCH_API Tensor computeQuantizedSigmoidExternalCall( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device); +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/reduction.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/reduction.h new file mode 100644 index 0000000000000000000000000000000000000000..615d75c397c921f47408e016a33541b278ff2e2c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/reduction.h @@ -0,0 +1,32 @@ +#pragma once + +#include + +namespace torch::jit::tensorexpr { + +TORCH_API Tensor computeSum( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +TORCH_API Tensor computeMean( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +TORCH_API Tensor computeAdaptiveAvgPool2d( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); +Tensor computeMax( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device); + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/softmax.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..f2a5698673cf3d1e4aba1e6899ef9254cb8f20d7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/operators/softmax.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace torch::jit::tensorexpr { + +Tensor computeSoftmax( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + bool log_softmax); + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/reduction.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/reduction.h new file mode 100644 index 0000000000000000000000000000000000000000..c65cf43be7fba00ff9cc974c7316d468deb3ad12 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/reduction.h @@ -0,0 +1,306 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::jit::tensorexpr { + +using ParameterList = const std::vector; +using ReduceInteraction = std::function; + +// A Reducer is a user interface describing a particular reduction +// operation. It has three components: An initialization value, a way of +// interacting each value with the accumulation, and a method for obtaining the +// current value to be reduced. It is materialized into a ReduceOp when loop +// variables are known. +class TORCH_API Reducer { + public: + Reducer(ExprHandle init, ReduceInteraction& interaction) + : init_(init.node()), interaction_(interaction) {} + + template + Reducer(ExprHandle init, RI interaction) + : init_(init.node()), interaction_(std::move(interaction)) {} + + ExprPtr initializer() const { + return init_; + } + + ExprHandle operator()( + const BufHandle& result_buf, + ExprHandle body, + const std::vector& output, + const std::vector& inner) const; + + ReduceOpPtr operator()( + const BufPtr& result_buf, + ExprPtr body, + const std::vector& output, + const std::vector& inner) const; + + ExprHandle operator()( + const BufHandle& result_buf, + BufHandle acc_buf, + const ExprHandle& body, + const std::vector& output, + const std::vector& inner) const; + + // Polymorphic handling of Body functions with a variety of parameters. + static ExprHandle getReduceBody( + const std::function& func, + const std::vector& vars) { + return func(vars); + } + + static ExprHandle getReduceBody( + const std::function& func, + const std::vector& vars) { + if (vars.size() != 1) { + throw malformed_input("mismatch between reduce body and arg size (1)"); + } + + return func(vars[0]); + } + + static ExprHandle getReduceBody( + const std::function& func, + const std::vector& vars) { + if (vars.size() != 2) { + throw malformed_input("mismatch between reduce body and arg size (2)"); + } + return func(vars[0], vars[1]); + } + + static ExprHandle getReduceBody( + const std::function< + ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>& + func, + const std::vector& vars) { + if (vars.size() != 3) { + throw malformed_input("mismatch between reduce body and arg size (3)"); + } + return func(vars[0], vars[1], vars[2]); + } + + static ExprHandle getReduceBody( + const std::function& func, + const std::vector& vars) { + if (vars.size() != 4) { + throw malformed_input("mismatch between reduce body and arg size (4)"); + } + return func(vars[0], vars[1], vars[2], vars[3]); + } + + // Completes the reduction operator by applying the interaction function to + // the accumulation and the body expression. + static ExprPtr complete( + const BufPtr& accumulator, + const ReduceInteraction& interaction, + ExprHandle body, + const std::vector& output_args, + const std::vector& reduce_args) { + ExprHandle accum = + ExprHandle(alloc(body.dtype(), accumulator, output_args)); + auto e = interaction(std::move(accum), std::move(body)); + return e.node(); + } + static ExprHandle complete( + const BufHandle& accumulator, + const ReduceInteraction& interaction, + ExprHandle body, + const std::vector& output_args, + const std::vector& reduce_args) { + ExprHandle accum = Load::make(body.dtype(), accumulator, output_args); + auto e = interaction(std::move(accum), std::move(body)); + return e; + } + + private: + ExprPtr init_; + ReduceInteraction interaction_; +}; + +// An expression representing a Reduction operation (e.g. Sum, Max) broken into +// it's component parts: initialization, accumulation var, acquisition of value +// to be reduced and interaction. +// +// This is intended to be expanded in the loopnest and not make it to codegen. +class TORCH_API ReduceOp : public ExprNode { + public: + ReduceOp( + const ExprPtr& body, + std::vector reduce_args, + Reducer reducer) + : ExprNodeBase(body->dtype()), + body_(body), + reduce_args_(std::move(reduce_args)), + reducer_(std::move(reducer)) { + result_buf_ = nullptr; + acc_buf_ = nullptr; + ri_operand_ = nullptr; + } + + ReduceOp( + const ExprPtr& body, + std::vector reduce_args, + BufPtr result_buf, + BufPtr acc_buf, + ExprPtr ri_operand, + Reducer reducer) + : ExprNodeBase(body->dtype()), + body_(body), + reduce_args_(std::move(reduce_args)), + result_buf_(std::move(result_buf)), + acc_buf_(std::move(acc_buf)), + ri_operand_(std::move(ri_operand)), + reducer_(std::move(reducer)) {} + + static ExprHandle make( + ExprHandle body, + const std::vector& reduce_args, + const Reducer& reducer); + + static ExprHandle make( + ExprHandle body, + const std::vector& reduce_args, + BufHandle result_buf, + BufHandle acc_buf, + ExprHandle ri_operand, + const Reducer& reducer); + + // return the body expression which obtains the value to be reduced. + ExprPtr body() const { + return body_; + } + + // Returns the original Reducer factory that can create ReduceOps. + const Reducer& reducer() const { + return reducer_; + } + + // returns variables associated with the axes of reduction. + const std::vector& reduce_args() const { + return reduce_args_; + } + + void setAccBuf(BufHandle acc_buf) { + acc_buf_ = acc_buf.node(); + } + BufPtr getAccBuf() { + return acc_buf_; + } + + void setResultBuf(BufHandle buf) { + result_buf_ = buf.node(); + } + BufPtr getResultBuf() { + return result_buf_; + } + + void setRiOperand(ExprHandle ri_operand) { + ri_operand_ = ri_operand.node(); + } + ExprPtr getRiOperand() { + return ri_operand_; + } + + private: + // body_ = reducer_->interaction_(result_buf_, ri_operand_) + ExprPtr body_; + std::vector reduce_args_; + + BufPtr result_buf_; + BufPtr acc_buf_; + ExprPtr ri_operand_; + + const Reducer reducer_; +}; + +class Sum : public Reducer { + public: + Sum() + : Reducer(ExprHandle(0), [](const ExprHandle& a, const ExprHandle& b) { + return a + b; + }) {} +}; + +inline ExprHandle maximumVal(ScalarType type) { + switch (type) { +#define MAX_BY_TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + return ExprHandle(std::numeric_limits::max()); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, MAX_BY_TYPE_CASE) +#undef MAX_BY_TYPE_CASE + default: + throw unsupported_dtype(); + } + return ExprHandle(); +} + +inline ExprHandle minimumVal(ScalarType type) { + switch (type) { +#define MAX_BY_TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + return ExprHandle(std::numeric_limits::min()); + AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, MAX_BY_TYPE_CASE) +#undef MAX_BY_TYPE_CASE + default: + throw unsupported_dtype(); + } +} + +class Maximum : public Reducer { + public: + // TODO possible to remove this arg by deferring the init value until we + // know the dtype of the body. + Maximum(Dtype dtype) + : Reducer( + minimumVal(dtype.scalar_type()), + [](const ExprHandle& a, const ExprHandle& b) { + return Max::make(a, b, true); + }) {} + Maximum(ExprHandle initializer) + : Reducer( + std::move(initializer), + [](const ExprHandle& a, const ExprHandle& b) { + return Max::make(a, b, true); + }) {} +}; + +class Minimum : public Reducer { + public: + Minimum(Dtype dtype) + : Reducer( + maximumVal(dtype.scalar_type()), + [](const ExprHandle& a, const ExprHandle& b) { + return Min::make(a, b, true); + }) {} + Minimum(const ExprHandle& initializer) + : Reducer(initializer, [](const ExprHandle& a, const ExprHandle& b) { + return Min::make(a, b, true); + }) {} +}; + +class ReductionExpander : public IRMutator { + public: + StmtPtr expand(const StmtPtr& s) { + return s->accept_mutator(this); + } + + ExprPtr mutate(const ReduceOpPtr& v) override { + return v->body(); + } +}; + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/registerizer.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/registerizer.h new file mode 100644 index 0000000000000000000000000000000000000000..752537bb089953df9abcbb86568ee8b2a69b8f57 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/registerizer.h @@ -0,0 +1,426 @@ +#pragma once +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +namespace torch::jit::tensorexpr { +namespace registerizer { + +/* The Registerizer performs scalar replacement by looking for common Stores and +Loads to a single item in a buffer and replacing them with a local temporary +scalar which is cheaper to write. + +For example it can replace: + +{ + A[0] = 0; + for(const auto x : c10::irange(10)) { + A[0] = (A[0]) + x; + } +} + +with: + +{ + int A_ = 0; + for(const auto x : c10::irange(10)) { + A_ = x + A_; + } + A[0] = A_; +} + +This is particularly useful on GPUs when parallelizing, since after replacing +loops with metavars we have a lot of accesses like this. */ + +class Scope; + +/* Holds analysis information about accesses to a specific range of a + buffer, including the number of loads and stores and the lowest common parent + Block. + */ +class AccessInfo { + public: + AccessInfo() = default; + AccessInfo( + SimplifierHashType h, + BufPtr b, + std::vector i, + size_t accessOrder) + : hash_(h), + buf_(std::move(b)), + indices_(std::move(i)), + store_cost_(alloc(0)), + load_cost_(alloc(0)), + accessOrder_(accessOrder) {} + + // Adds a Store to this access, which is in the provided scope. + void addStore(const StorePtr& store, const std::shared_ptr& scope); + + // Adds a Load to this access, which occurs in the usage Stmt in the provided + // scope. + void addLoad( + const LoadPtr& load, + const std::shared_ptr& scope, + const StmtPtr& usage); + + // Merge another AccessInfo into this one. + void merge(const std::shared_ptr& other); + + // Returns true if the other AccessInfo's bounds may overlap this one. + bool overlaps(const std::shared_ptr& other); + + // Returns true if the indices of this access depend on the provided Var. + bool dependsOnVar(const VarPtr& v); + + // Clone this AccessInfo, and set this as the new accesses' hiddenAccess. + static std::shared_ptr cloneWithHiddenInfo( + const std::shared_ptr& orig); + + // print for debugging. + void print() const; + + SimplifierHashType hash() const { + return hash_; + } + + BufPtr buf() const { + return buf_; + } + + const std::vector& indices() const { + return indices_; + } + + BlockPtr block() const { + return block_; + } + + void setEnclosingBlock(BlockPtr b) { + block_ = std::move(b); + } + + StmtPtr first_usage() const { + return first_usage_; + } + StmtPtr last_usage() const { + return last_usage_; + } + + void setUsageMarks(StmtPtr first, StmtPtr last) { + first_usage_ = std::move(first); + last_usage_ = std::move(last); + } + + bool firstUsageOverlapped() const { + return firstUsageOverlapped_; + } + + ExprPtr store_cost() const { + return store_cost_; + } + + ExprPtr load_cost() const { + return load_cost_; + } + + const std::vector& stores() const { + return stores_; + } + + const std::vector& loads() const { + return loads_; + } + + void hoistCosts(const ExprPtr& extent) { + store_cost_ = IRSimplifier::simplify(alloc(store_cost_, extent)); + load_cost_ = IRSimplifier::simplify(alloc(load_cost_, extent)); + } + + size_t conditionId() const { + return conditionId_; + } + + void setConditionId(size_t c) { + conditionId_ = c; + } + + size_t accessOrder() const { + return accessOrder_; + } + + std::shared_ptr hiddenAccess() const { + return hiddenAccess_; + } + + // Holds state relating to the scalar variable we will insert to replace some + // number of loads and stores. + struct ScalarReplacement { + VarPtr var{nullptr}; + BufPtr var_wrapper{nullptr}; + LetPtr initializer{nullptr}; + }; + + ScalarReplacement& replacement() { + return replacement_; + } + + private: + SimplifierHashType hash_; + BufPtr buf_; + std::vector indices_; + BlockPtr block_{nullptr}; + + StmtPtr first_usage_{nullptr}; + StmtPtr last_usage_{nullptr}; + + // Whether or not this access is overlapped in the first Stmt it appears. This + // means we cannot use it's first Store as the initializer. + bool firstUsageOverlapped_{false}; + + // The cost in real ops that this access represents, to enable + // filtering accesses that wont save any loads or stores. + ExprPtr store_cost_; + ExprPtr load_cost_; + + // The actual Stores and Loads which represent this access. + // Be careful with these, any mutator will invalidate these pointers. + std::vector stores_; + std::vector loads_; + + // An identifier representing the conditional block, if any, this access + // depends on. + size_t conditionId_{0}; + + // An identifier representing the order this access was first encountered, for + // sorting returned results. + size_t accessOrder_{0}; + + // Sometimes when traversing the tree we need to record what would happen if + // we hoisted an access, but sometimes it doesn't work out. This lets us + // "undo" some mutation and return to the internal hidden AccessInfo. + // It will be removed after any further additions to this AccessInfo. + std::shared_ptr hiddenAccess_; + + ScalarReplacement replacement_; +}; + +using AccessHashMap = + std::unordered_map>; + +// Represents a scope block and holds all accesses contained within it. +class Scope { + public: + Scope(BlockPtr b, std::shared_ptr parent, size_t conditionId = 0) + : block_(std::move(b)), + parent_(std::move(parent)), + conditionId_(conditionId) {} + + AccessHashMap& getAccessMapByBuf(const BufPtr& b); + + std::unordered_map& openAccesses() { + return openAccesses_; + } + + std::vector>& closedAccesses() { + return closedAccesses_; + } + + BlockPtr block() const { + return block_; + } + + std::shared_ptr parent() const { + return parent_; + } + + size_t conditionId() const { + return conditionId_; + } + + const std::unordered_set& localVars() const { + return localVars_; + } + void addLocalVar(VarPtr v) { + localVars_.insert(std::move(v)); + } + + void closeAccess(const std::shared_ptr& info); + + void filterClosed(); + + private: + // Map of map to access, narrowing by Buf then by hash(Buf+Indices). + // This allows us to find a candidate access easily, and also check for + // overlap with other accesses to the same buf. Buf -> + // Hash -> + // Access + std::unordered_map openAccesses_; + std::vector> closedAccesses_; + + // The Block object this scope represents. + BlockPtr block_; + + // The enclosing scope object. + std::shared_ptr parent_; + + // An identifier representing the condition block this scope depends on. + size_t conditionId_; + + // A set of variables local to this scope (e.g. loop vars). + std::unordered_set localVars_; +}; + +/* Analyzes the graph and collects accesses to the same symbolic tensor element + * which can be replaced by a single local scalar. + * + * This works by recursively walking the tree in postfix order, building sets of + * accesses to the same symbolic element by scope and then merging lower scopes + * into their enclosing scope. + * + * It is safe to move two accesses of the same Tensor element to a local scalar + * Var if between all usages of the element there are no other Loads or Stores + * that may refer to it. In the comments I refer to this as overlapping the + * access, or "cutting" the existing AccessInfo. In the case where a candidate + * for registerization is cut, it may be possible to finalize the access early + * by writing it back to the Tensor and then create a new scalar variable after + * the overlapping access is complete. We will attempt to do this when it saves + * memory accesses. + * + * There are a few cases that make this more challenging: + * + * - For: Loops change the number of real usages of a buffer by the loop + * extent, but only if we can pull the definition and finalization of the scalar + * variable out of the loop block. + * + * - Cond: Conditions complicate lifting scalars out of internal scopes. + * Generally we cannot lift an access outside of a conditional scope unless + * there is already a reference to that same access at the higher scope, since + * we don't know if the condition was guarding an array access not safe at the + * higher scope. In the comments I refer to this as the condition "hiding" the + * access, and the outer access "unhiding" it. + * + * - IfThenElse: Same situation as Cond, except since IfThenElse is an Expr + * rather than a Stmt we cannot insert the scalar definition or finalizer + * within the conditional scope. Accesses inside an IfThenElse can be safely + * combined with external accesses but cannot exist completely within. + * + * - Let: Accesses dependent on local variables via Let Stmts, or loop vars, + * cannot be raised outside of the scope of the dependent var. + */ +class TORCH_API RegisterizerAnalysis : public IRVisitor { + public: + RegisterizerAnalysis() + : currentScope_(std::make_shared(nullptr, nullptr, 0)) {} + ~RegisterizerAnalysis() override = default; + + void visit(const ForPtr& v) override; + + void visit(const CondPtr& v) override; + + void visit(const BlockPtr& v) override; + + void visit(const StorePtr& v) override; + + void visit(const LoadPtr& v) override; + + void visit(const IfThenElsePtr& v) override; + + void visit(const LetPtr& v) override; + +#define STMT_ON_STACK(Op) \ + void visit(const Op##Ptr& v) override { \ + stmtStack_.push_front(v); \ + IRVisitor::visit(v); \ + stmtStack_.pop_front(); \ + } + + STMT_ON_STACK(AtomicAdd) + STMT_ON_STACK(Allocate) + STMT_ON_STACK(Free) + +#undef STMT_ON_STACK + + std::vector> getCandidates(); + + private: + void mergeCurrentScopeIntoParent(); + void mergeHiddenScope(bool allowClosed); + void closeAccessIntoScope( + const std::shared_ptr& info, + const std::shared_ptr& scope); + + std::unordered_set exprConditionals_; + + // A stack of enclosing Stmts for tracking the usage Stmt of Loads. + std::deque stmtStack_; + + // The current scope being analyzed. + std::shared_ptr currentScope_; + + HashProvider hasher_; + + size_t conditionId_{0}; + size_t accessOrder_{0}; +}; + +/* Replaces each registerizable access with a Scalar variable, including + * definition, initializer and finalizer. + */ +class TORCH_API RegisterizerReplacer : public IRMutator { + public: + RegisterizerReplacer(std::vector>& vec) + : infoSet_(vec) { + buildReplacements(); + } + + ExprPtr mutate(const LoadPtr& v) override; + + StmtPtr mutate(const StorePtr& v) override; + + StmtPtr mutate(const BlockPtr& v) override; + + private: + struct ReplacerScope { + std::unordered_map>> + initializerPoints_; + std::unordered_map>> + finalizePoints_; + }; + + // Creates the various ReplacerScope objects and builds internal maps. + void buildReplacements(); + + // State relating to the accesses yet to be replaced. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + std::vector>& infoSet_; + std::unordered_map> storeToAccess_; + std::unordered_map> loadToAccess_; + std::unordered_map parentToAccesses_; + + // Holds the set of Stores that should be pulled into an initializer, so they + // can be eliminated. + std::set eliminatedIntializers_; + + // Tracks the number of times we've seen each buffer, so we can name the + // scalar Vars appropriately. + std::unordered_map bufferAccessCounts_; + unsigned int getBufferAccessCount(const BufPtr& b) { + return ++bufferAccessCounts_[b]; + } +}; +} // namespace registerizer + +// Apply scalar replacement to all accesses in s. +// To produce safe code, this must occur after handling parallelized axes and +// atomics. +TORCH_API StmtPtr registerize(StmtPtr s); + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/stmt.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/stmt.h new file mode 100644 index 0000000000000000000000000000000000000000..5cdbe7de5217409ef6ac58038e4b6279a5bcf4ac --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/stmt.h @@ -0,0 +1,1012 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace torch::jit::tensorexpr { + +// The common base between all statement node. +class TORCH_API Stmt : public std::enable_shared_from_this { + public: + Stmt() = default; + virtual ~Stmt() = default; + virtual void accept(IRVisitor* visitor) = 0; + virtual StmtPtr accept_mutator(IRMutator* mutator) = 0; + + StmtPtr get_parent() const { + return parent_ ? parent_->getptr() : nullptr; + } + + /* + * Make a deep copy of the given statement. + * + * All statements and expressions used in children of the statement are + * cloned. Note that the variables are not deep-copied since they are + * immutable. + */ + static StmtPtr clone(const StmtPtr& s); + + protected: + static void set_parent(const StmtPtr& s, Stmt* new_parent) { + s->parent_ = new_parent; + } + std::shared_ptr getptr() { + return shared_from_this(); + } + + private: + Stmt* parent_ = nullptr; +}; + +template +class StmtNode : public Stmt { + public: + using StmtNodeBase = StmtNode; + void accept(IRVisitor* visitor) override { + visitor->visit(static_to(getptr())); + } + StmtPtr accept_mutator(IRMutator* mutator) override; + friend Op; + + private: + StmtNode() = default; +}; + +template +StmtPtr StmtNode::accept_mutator(IRMutator* mutator) { + return mutator->mutate(static_to(getptr())); +} + +// Concrete Stmt classes +class TORCH_API Block : public StmtNode { + public: + static BlockPtr make(const std::vector& stmts) { + std::vector valid_stmts; + for (auto& stmt : stmts) { + if (!stmt) { + continue; + } + valid_stmts.push_back(stmt); + } + if (valid_stmts.empty()) { + return nullptr; + } + return alloc(valid_stmts); + } + + size_t nstmts() const { + return stmts_.size(); + } + bool empty() const { + return stmts_.empty(); + } + + void prepend_stmt(const StmtPtr& s) { + if (s->get_parent()) { + throw malformed_input("Block prepend Stmt with existing parent", s); + } + + stmts_.push_front(s); + set_parent(s, this); + } + void append_stmt(const StmtPtr& s) { + if (s->get_parent()) { + throw malformed_input("Block append Stmt with existing parent", s); + } + + stmts_.push_back(s); + set_parent(s, this); + } + + void insert_stmt_before(const StmtPtr& s, const StmtPtr& before) { + if (s->get_parent()) { + throw malformed_input("Block append Stmt with existing parent", s); + } + + auto pos = std::find(stmts_.begin(), stmts_.end(), before); + if (pos == stmts_.end()) { + throw malformed_input( + "Inserting after statement that is not in block", s); + } + + stmts_.insert(pos, s); + set_parent(s, this); + } + + void insert_stmt_after(const StmtPtr& s, const StmtPtr& after) { + if (s->get_parent()) { + throw malformed_input("Block append Stmt with existing parent", s); + } + + auto pos = std::find(stmts_.begin(), stmts_.end(), after); + if (pos == stmts_.end()) { + throw malformed_input( + "Inserting after statement that is not in block", s); + } + + ++pos; + + stmts_.insert(pos, s); + set_parent(s, this); + } + + bool replace_stmt(const StmtPtr& old_stmt, const StmtPtr& new_stmt) { + if (new_stmt->get_parent()) { + throw malformed_input( + "Block replace Stmt with existing parent", new_stmt); + } + + auto pos = std::find(stmts_.begin(), stmts_.end(), old_stmt); + if (pos == stmts_.end()) { + return false; + } + stmts_.insert(pos, new_stmt); + stmts_.erase(pos); + set_parent(old_stmt, nullptr); + set_parent(new_stmt, this); + return true; + } + + // Creates a new block by cloning `this` block and replacing the given + // statement with a new statement. Note that `old_stmt` refers to a statement + // in `this` block. If the `old_stmt` is not found, it will return `nullptr`. + BlockPtr clone_and_replace(const StmtPtr& old_stmt, const StmtPtr& new_stmt) { + if (new_stmt->get_parent()) { + throw malformed_input( + "Block replace Stmt with existing parent", new_stmt); + } + + std::vector stmts(stmts_.begin(), stmts_.end()); + std::vector cloned_stmts(stmts.size()); + bool found = false; + for (int i = 0; i < static_cast(stmts.size()); ++i) { + if (stmts[i] == old_stmt) { + found = true; + cloned_stmts[i] = new_stmt; + } else { + cloned_stmts[i] = Stmt::clone(stmts[i]); + } + } + if (!found) { + return nullptr; + } + return alloc(cloned_stmts); + } + + bool remove_stmt(const StmtPtr& stmt) { + auto pos = std::find(stmts_.begin(), stmts_.end(), stmt); + if (pos == stmts_.end()) { + return false; + } + + set_parent(stmt, nullptr); + stmts_.erase(pos); + return true; + } + + std::list stmts() const { + return stmts_; + } + + void clear() { + for (const auto& s : stmts_) { + set_parent(s, nullptr); + } + stmts_.clear(); + } + + void set_stmts(const std::vector& stmts) { + clear(); + init(stmts); + } + + explicit Block(const std::vector& stmts) { + init(stmts); + } + + typedef std::list::iterator iterator; + typedef std::list::const_iterator const_iterator; + + iterator begin() { + return stmts_.begin(); + } + + const_iterator begin() const { + return stmts_.begin(); + } + + iterator end() { + return stmts_.end(); + } + + const_iterator end() const { + return stmts_.end(); + } + + StmtPtr front() { + return stmts_.front(); + } + + StmtPtr front() const { + return stmts_.front(); + } + + StmtPtr back() { + return stmts_.back(); + } + + StmtPtr back() const { + return stmts_.back(); + } + + void splice(Block::iterator it, const BlockPtr& other) { + for (const StmtPtr& s : *other) { + set_parent(s, this); + } + + stmts_.splice(it, other->stmts_); + } + + static BlockPtr getSharedParent(StmtPtr p1, StmtPtr p2) { + std::unordered_set enclosing; + + StmtPtr p1_p = std::move(p1); + while (p1_p) { + if (BlockPtr b = to(p1_p)) { + enclosing.insert(b); + } + p1_p = p1_p->get_parent(); + } + + StmtPtr p2_p = std::move(p2); + while (p2_p) { + if (BlockPtr b = to(p2_p)) { + if (enclosing.count(b) != 0) { + return b; + } + } + p2_p = p2_p->get_parent(); + } + + return nullptr; + } + + // returns the immediate child containing statement s. + StmtPtr getEnclosedRoot(StmtPtr s) const { + while (s && s->get_parent().get() != this) { + s = s->get_parent(); + } + return s; + } + + private: + std::list stmts_; + + void init(const std::vector& stmts) { + for (const StmtPtr& s : stmts) { + if (!s) { + continue; + } + if (!s->get_parent()) { + // If we get here, it's a bug, but we cannot throw an error from a + // constructor. But IR verifier would catch this. + set_parent(s, this); + } + + stmts_.push_back(s); + } + } +}; + +class TORCH_API Store : public StmtNode { + public: + VarPtr base_handle() const { + return buf_->base_handle(); + } + std::vector indices() const { + return indices_; + } + ExprPtr flat_index() const { + TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened."); + return indices_[0]; + } + ExprPtr value() const { + return value_; + } + BufPtr buf() const { + return buf_; + } + + void set_buf(BufPtr buf) { + buf_ = std::move(buf); + } + + void set_indices(std::vector indices) { + indices_ = std::move(indices); + } + + void set_value(ExprPtr value) { + value_ = std::move(value); + } + + static StorePtr make( + const BufHandle& buf, + const std::vector& indices, + const ExprHandle& value); + + Store(BufPtr buf, std::vector indices, ExprPtr value); + + private: + BufPtr buf_; + std::vector indices_; + ExprPtr value_; +}; + +// Allocate a buffer of given shapes and dtypes and bind it with the given +// buffer var. The life span is at most through the current program, until it is +// explicitly freed. An unfreed memory is likely considered an error. +class TORCH_API Allocate : public StmtNode { + public: + static AllocatePtr make(const BufHandle& buf_handle) { + return alloc(buf_handle.node()); + } + + VarPtr buffer_var() const { + return buf_->base_handle(); + } + + Dtype dtype() const { + return buf_->dtype(); + } + + const std::vector dims() const { + return buf_->dims(); + } + + BufPtr buf() const { + return buf_; + } + + void set_buf(BufPtr buf) { + buf_ = std::move(buf); + } + + explicit Allocate(BufPtr buf) : buf_(std::move(buf)) {} + + private: + BufPtr buf_; + // TODO: add memory types. +}; + +// PlacementAllocate is a variation of the Allocate operator in NNC IR. It does +// not allocate memory but reuse the memory of another buffer for the given +// buffer. +class TORCH_API PlacementAllocate : public StmtNode { + public: + static PlacementAllocatePtr make( + const BufHandle& buf_handle, + const BufHandle& buf_handle_to_reuse) { + return alloc( + buf_handle.node(), buf_handle_to_reuse.node()); + } + + BufPtr buf() const { + return buf_; + } + + BufPtr buf_to_reuse() const { + return buf_to_reuse_; + } + + void set_buf(BufPtr buf) { + buf_ = std::move(buf); + } + + void set_buf_to_reuse(BufPtr buf) { + buf_to_reuse_ = std::move(buf); + } + + explicit PlacementAllocate(BufPtr buf, BufPtr buf_to_reuse) + : buf_(std::move(buf)), buf_to_reuse_(std::move(buf_to_reuse)) {} + + private: + BufPtr buf_; + BufPtr buf_to_reuse_; +}; + +// Free the specific buffer. It is an error. +class TORCH_API Free : public StmtNode { + public: + static FreePtr make(const BufHandle& buf_handle) { + return alloc(buf_handle.node()); + } + + VarPtr buffer_var() const { + return buf_->base_handle(); + } + + BufPtr buf() const { + return buf_; + } + + void set_buf(BufPtr buf) { + buf_ = std::move(buf); + } + + explicit Free(BufPtr buf) : buf_(std::move(buf)) {} + + private: + BufPtr buf_; +}; + +class TORCH_API FreeExt : public StmtNode { + public: + static FreeExtPtr make(const std::vector& bufs); + + std::vector bufs() const { + return bufs_; + } + + void set_bufs(std::vector bufs) { + bufs_ = std::move(bufs); + } + + explicit FreeExt(std::vector bufs) : bufs_(std::move(bufs)) {} + + private: + std::vector bufs_; +}; + +class TORCH_API Let : public StmtNode { + public: + static LetPtr make(const VarHandle& var, const ExprHandle& val) { + return alloc(var.node(), val.node()); + } + + Let(VarPtr var, ExprPtr val) : var_(std::move(var)), val_(std::move(val)) {} + + VarPtr var() const { + return var_; + } + + ExprPtr value() const { + return val_; + } + + void set_var(VarPtr var) { + var_ = std::move(var); + } + + void set_val(ExprPtr val) { + val_ = std::move(val); + } + + private: + VarPtr var_; + ExprPtr val_; +}; + +class TORCH_API Cond : public StmtNode { + public: + static CondPtr make( + const ExprHandle& condition, + const StmtPtr& true_stmt, + const StmtPtr& false_stmt) { + return alloc(condition.node(), true_stmt, false_stmt); + } + + ExprPtr condition() const { + return condition_; + } + + BlockPtr true_stmt() const { + return true_stmt_; + } + + BlockPtr false_stmt() const { + return false_stmt_; + } + + void set_condition(ExprPtr condition) { + condition_ = std::move(condition); + } + + void set_true_stmt(StmtPtr true_stmt) { + if (true_stmt) { + BlockPtr b = to(true_stmt); + if (!b) { + b = alloc(std::vector({std::move(true_stmt)})); + } + true_stmt_ = b; + set_parent(true_stmt_, this); + } + } + + void set_false_stmt(StmtPtr false_stmt) { + if (false_stmt) { + BlockPtr b = to(false_stmt); + if (!b) { + b = alloc(std::vector({std::move(false_stmt)})); + } + false_stmt_ = b; + set_parent(false_stmt_, this); + } + } + + Cond(ExprPtr condition, StmtPtr true_stmt, StmtPtr false_stmt) + : condition_(std::move(condition)) { + set_true_stmt(std::move(true_stmt)); + set_false_stmt(std::move(false_stmt)); + } + + CondPtr cloneWithNewBodies( + const StmtPtr& true_stmt, + const StmtPtr& false_stmt) { + return alloc(condition_, true_stmt, false_stmt); + } + + CondPtr cloneWithNewBody(const StmtPtr& true_stmt) { + return alloc(condition_, true_stmt, nullptr); + } + + private: + ExprPtr condition_; + BlockPtr true_stmt_ = nullptr; + BlockPtr false_stmt_ = nullptr; +}; + +class TORCH_API LoopOptions { + public: + enum { + IDX_UNSET = -1, + IDX_X = 0, + IDX_Y = 1, + IDX_Z = 2, + IDX_W = 3, + IDX_MAX = IDX_W, + }; + // GPU Block Index + bool is_gpu_block_index() const { + return gpu_block_index_ != IDX_UNSET; + } + + int gpu_block_index() const { + return gpu_block_index_; + } + + std::string gpu_block_index_str() const { + if (!is_gpu_block_index()) { + throw malformed_input("Has no GPU block index"); + } + + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) + static const char* kBlockIndexNames[] = { + "blockIdx.x", + "blockIdx.y", + "blockIdx.z", + "blockIdx.w", + }; + + if (gpu_block_index_ < IDX_X || gpu_block_index_ > IDX_MAX) { + throw malformed_input("invalid GPU block index"); + } + + return kBlockIndexNames[gpu_block_index_]; + } + + void set_gpu_block_index(int index) { + if (index == IDX_UNSET) { + gpu_block_index_ = IDX_UNSET; + } + + if (is_gpu_thread_index()) { + throw std::runtime_error("Cannot set both gpu block and thread index"); + } + if (is_gpu_block_index() && gpu_block_index() != index) { + throw std::runtime_error("Cannot set a previously set block index"); + } + gpu_block_index_ = index; + } + + // GPU Thread Index + bool is_gpu_thread_index() const { + return gpu_thread_index() != IDX_UNSET; + } + + int gpu_thread_index() const { + return gpu_thread_index_; + } + + std::string gpu_thread_index_str() const { + if (!is_gpu_thread_index()) { + throw malformed_input("has no GPU thread index"); + } + + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) + static const char* kThreadIndexNames[] = { + "threadIdx.x", "threadIdx.y", "threadIdx.z", "threadIdx.w"}; + + if (gpu_thread_index_ < IDX_X || gpu_thread_index_ > IDX_MAX) { + throw malformed_input("invalid GPU thread index"); + } + + return kThreadIndexNames[gpu_thread_index_]; + } + + void set_gpu_thread_index(int index) { + if (index == IDX_UNSET) { + gpu_thread_index_ = IDX_UNSET; + } + + if (is_gpu_block_index()) { + throw std::runtime_error("Cannot set both gpu thread and block index"); + } + if (is_gpu_thread_index() && gpu_thread_index() != index) { + throw std::runtime_error("Cannot set a previously set thread index"); + } + gpu_thread_index_ = index; + } + + void set_parallel() { + is_parallel_ = true; + } + + bool is_parallel() const { + return is_parallel_; + } + + std::string ToString() const { + if (is_gpu_block_index()) { + return gpu_block_index_str(); + } else if (is_gpu_thread_index()) { + return gpu_thread_index_str(); + } else if (is_parallel()) { + return "parallel"; + } + return ""; + } + + bool isDefault() const { + return gpu_block_index_ == IDX_UNSET && gpu_thread_index_ == IDX_UNSET && + !is_parallel_; + } + + void set_buffer_mapping(const std::unordered_map& map) { + map_input_to_tensor_bufs_ = map; + } + + std::unordered_map get_buffer_mapping() const { + return map_input_to_tensor_bufs_; + } + + private: + int gpu_block_index_{IDX_UNSET}; + int gpu_thread_index_{IDX_UNSET}; + bool is_parallel_{false}; + std::unordered_map map_input_to_tensor_bufs_; +}; + +class TORCH_API For : public StmtNode { + public: + VarPtr var() const { + return var_; + } + ExprPtr start() const { + return start_; + } + ExprPtr stop() const { + return stop_; + } + BlockPtr body() const { + return body_; + } + static ForPtr make( + const VarHandle& var, + const ExprHandle& start, + const ExprHandle& stop, + const StmtPtr& body) { + if (!body) { + return nullptr; + } + return alloc(var.node(), start.node(), stop.node(), body); + } + static ForPtr make( + const VarHandle& var, + const ExprHandle& start, + const ExprHandle& stop, + const StmtPtr& body, + const LoopOptions& loop_options) { + if (!body) { + return nullptr; + } + return alloc( + var.node(), start.node(), stop.node(), body, loop_options); + } + const LoopOptions loop_options() const { + return loop_options_; + } + + For(VarPtr var, ExprPtr start, ExprPtr stop, StmtPtr body) + : var_(std::move(var)), start_(std::move(start)), stop_(std::move(stop)) { + BlockPtr b = to(body); + if (!b) { + b = alloc(std::vector({std::move(body)})); + } + body_ = b; + set_parent(body_, this); + } + + For(VarPtr var, + ExprPtr start, + ExprPtr stop, + StmtPtr body, + LoopOptions loop_options) + : var_(std::move(var)), + start_(std::move(start)), + stop_(std::move(stop)), + loop_options_(std::move(loop_options)) { + if (!var_) { + throw malformed_input("invalid Var in For loop"); + } else if (!start_) { + throw malformed_input("invalid Start in For loop"); + } else if (!stop_) { + throw malformed_input("invalid Stop in For loop"); + } else if (!body || body->get_parent()) { + throw malformed_input("invalid Body in For loop"); + } + + BlockPtr b = to(body); + if (!b) { + b = alloc(std::vector({std::move(body)})); + } + body_ = b; + set_parent(body_, this); + } + + void set_gpu_block_index(int block_index) { + loop_options_.set_gpu_block_index(block_index); + } + + void set_gpu_thread_index(int thread_index) { + loop_options_.set_gpu_thread_index(thread_index); + } + + void set_parallel() { + loop_options_.set_parallel(); + } + + bool is_parallel() const { + return loop_options_.is_parallel(); + } + + void set_buffer_map(const std::unordered_map& map) { + loop_options_.set_buffer_mapping(map); + } + + ForPtr cloneWithNewBody(const StmtPtr& body) const { + return alloc(var_, start_, stop_, body, loop_options_); + } + + BlockPtr removeBody() { + auto res = body_; + set_parent(res, nullptr); + body_ = nullptr; + return res; + } + + void set_body(StmtPtr body) { + BlockPtr b = to(body); + if (!b) { + b = alloc(std::vector({std::move(body)})); + } + body_ = b; + set_parent(body_, this); + } + + void set_start(ExprPtr start) { + start_ = std::move(start); + } + + void set_stop(ExprPtr stop) { + stop_ = std::move(stop); + } + + void set_var(VarPtr var) { + var_ = std::move(var); + } + + private: + VarPtr var_; + ExprPtr start_; + ExprPtr stop_; + BlockPtr body_; + LoopOptions loop_options_; +}; + +// A backend specific IR Node that implements atomic-add. +// This node could only shows up as an internal with GPU backends. +// TODO: move to this an internal IR. +// TODO: make IR nodes extensible. +class TORCH_API AtomicAdd : public StmtNode { + public: + AtomicAdd(BufPtr buf, std::vector indices, ExprPtr value) + : buf_(std::move(buf)), + indices_(std::move(indices)), + value_(std::move(value)) {} + + VarPtr base_handle() const { + return buf_->base_handle(); + } + + BufPtr buf() const { + return buf_; + } + + ExprPtr flat_index() const { + TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened."); + return indices_[0]; + } + + ExprPtr value() const { + return value_; + } + + const std::vector& indices() const { + return indices_; + } + + void set_buf(BufPtr buf) { + buf_ = std::move(buf); + } + + void set_indices(std::vector indices) { + indices_ = std::move(indices); + } + + void set_value(ExprPtr value) { + value_ = std::move(value); + } + + private: + BufPtr buf_; + std::vector indices_; + ExprPtr value_; +}; + +class TORCH_API SyncThreads : public StmtNode { + public: + SyncThreads() = default; +}; + +/* + * ExternalCall statement represents a call to an external function that would + * compute the contents of the output buffer. An ExternalCall statement consists + * of: + * 1) output buffer - the buffer that'll be initialized by the call + * 2) external function name - a key from the NNC function registry to lookup + * the actual function to call + * 3) buffer arguments - the input buffers used by the function + * 4) non-buffer arguments - scalar arguments to pass to the function + * + * An example: + * A = nnc_conv2d(buf_args={Input, Weight, Bias}, args={1}) + * Here 'A' is the output buffer, "nnc_conv2d" is the function name, the buffer + * arguments are 'Input', 'Weight', and 'Bias', and there is a single non-buffer + * argument - 1. + * + * The semantics of the scalar arguments is defined solely by the implementation + * of the external function. + */ +class TORCH_API ExternalCall : public StmtNode { + public: + static ExternalCallPtr make( + BufHandle buf, + const std::string& func_name, + const std::vector& buf_args, + const std::vector& args); + + BufPtr buf() const { + return buf_; + } + + std::string func_name() const { + return func_name_; + } + + std::vector buf_args() const { + return buf_args_; + } + + std::vector args() const { + return args_; + } + + void set_buf(BufPtr buf) { + buf_ = std::move(buf); + } + + void set_buf_args(std::vector buf_args) { + buf_args_ = std::move(buf_args); + } + + void set_args(std::vector args) { + args_ = std::move(args); + } + + ExternalCall( + BufPtr buf, + std::string func_name, + std::vector buf_args, + std::vector args) + : buf_(std::move(buf)), + func_name_(std::move(func_name)), + buf_args_(std::move(buf_args)), + args_(std::move(args)) {} + + private: + BufPtr buf_; + std::string func_name_; + std::vector buf_args_; + std::vector args_; +}; + +class TORCH_API ExternalCallWithAlloc : public StmtNode { + public: + static ExternalCallWithAllocPtr make( + const std::string& func_name, + const std::vector& buf_out_args, + const std::vector& buf_args, + const std::vector& args); + + std::vector buf_out_args() const { + return buf_out_args_; + } + + std::string func_name() const { + return func_name_; + } + + std::vector buf_args() const { + return buf_args_; + } + + std::vector args() const { + return args_; + } + + void set_buf_out_args(std::vector buf_out_args) { + buf_out_args_ = std::move(buf_out_args); + } + + void set_buf_args(std::vector buf_args) { + buf_args_ = std::move(buf_args); + } + + void set_args(std::vector args) { + args_ = std::move(args); + } + + ExternalCallWithAlloc( + std::string func_name, + std::vector buf_out_args, + std::vector buf_args, + std::vector args) + : func_name_(std::move(func_name)), + buf_out_args_(std::move(buf_out_args)), + buf_args_(std::move(buf_args)), + args_(std::move(args)) {} + + private: + std::string func_name_; + std::vector buf_out_args_; + std::vector buf_args_; + std::vector args_; +}; + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/tensor.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..c7b2e65b44771b5ac15d01910b9d20757d61d0e8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/tensor.h @@ -0,0 +1,321 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace torch::jit::tensorexpr { + +class TORCH_API Tensor { + public: + Tensor(BufPtr buf, const std::vector& args, const ExprPtr& body) + : buf_(std::move(buf)) { + stmt_ = constructStmt(args, body, {}, {}); + } + Tensor(BufHandle buf, const std::vector& args, ExprHandle body) + : Tensor(buf.node(), VarHandleVectorToVarVector(args), body.node()) {} + + Tensor( + BufPtr buf, + const std::vector& args, + const std::vector& reduce_dims, + const std::vector& reduce_args, + const ExprPtr& body) + : buf_(std::move(buf)) { + stmt_ = constructStmt(args, body, reduce_dims, reduce_args); + } + Tensor( + BufHandle buf, + const std::vector& args, + const std::vector& reduce_dims, + const std::vector& reduce_args, + ExprHandle body) + : Tensor( + buf.node(), + VarHandleVectorToVarVector(args), + ExprHandleVectorToExprVector(reduce_dims), + VarHandleVectorToVarVector(reduce_args), + body.node()) {} + + Tensor(BufPtr buf, StmtPtr stmt) + : buf_(std::move(buf)), stmt_(std::move(stmt)) {} + + BufPtr buf() const { + return buf_; + } + + StmtPtr stmt() const { + return stmt_; + } + + template + inline ExprHandle load(const std::vector& args) const; + template + inline ExprHandle load(const Ts&... ts) const; + + private: + StmtPtr constructStmt( + const std::vector& args, + const ExprPtr& body, + const std::vector& reduce_dims, + const std::vector& reduce_args) const; + + BufPtr buf_; + StmtPtr stmt_; +}; + +TORCH_API Tensor Compute( + const std::string& func_name, + const std::vector& dims, + const std::optional>& strides, + const std::function& body_func); +TORCH_API Tensor Compute( + const std::string& func_name, + const std::vector& dims, + const std::function& body_func); +TORCH_API Tensor Compute( + const std::string& func_name, + const std::vector& dims, + const std::optional>& strides, + const std::function& + body_func); +TORCH_API Tensor Compute( + const std::string& func_name, + const std::vector& dims, + const std::function& + body_func); +TORCH_API Tensor Compute( + const std::string& func_name, + const std::vector& dims, + const std::optional>& strides, + const std::function< + ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>& + body_func); +TORCH_API Tensor Compute( + const std::string& func_name, + const std::vector& dims, + const std::function< + ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>& + body_func); +TORCH_API Tensor Compute( + const std::string& func_name, + const std::vector& dims, + const std::optional>& strides, + const std::function& body_func); +TORCH_API Tensor Compute( + const std::string& func_name, + const std::vector& dims, + const std::function& body_func); +TORCH_API Tensor Compute( + const std::string& func_name, + const std::vector& dims, + const std::optional>& strides, + const std::function&)>& body_func); +TORCH_API Tensor Compute( + const std::string& func_name, + const std::vector& dims, + const std::function&)>& body_func); + +inline std::vector create_index_vars( + const std::vector& dims) { + std::vector vars; + vars.reserve(dims.size()); + for (const ExprHandle& dim : dims) { + vars.emplace_back(alloc( + "i", dim.dtype().scalar_type() == ScalarType::Long ? kLong : kInt)); + } + return vars; +} + +// Handle reductions over a Reducer and a body_func which produces values. +template +Tensor Reduce( + const std::string& func_name, + const std::vector& dims, + const std::optional>& strides, + const Reducer& reducer, + const InitFunc& init_func, + const BodyFunc& body_func, + const std::vector& reduce_dims) { + std::vector vars = create_index_vars(dims); + std::vector reduce_vars = create_index_vars(reduce_dims); + + // If reduce_vars is empty, then it's not a reduction, but rather a simple + // copy + if (reduce_vars.empty()) { + ExprHandle body = Reducer::getReduceBody(body_func, vars); + BufHandle func_result = + Buf::make(func_name, dims, body.dtype(), std::nullopt, strides); + return Tensor(std::move(func_result), vars, std::move(body)); + } + + std::vector all_vars; + all_vars.insert(all_vars.end(), vars.begin(), vars.end()); + all_vars.insert(all_vars.end(), reduce_vars.begin(), reduce_vars.end()); + + ExprHandle body = Reducer::getReduceBody(body_func, all_vars); + std::vector output_args(vars.begin(), vars.end()); + ExprHandle init_expr = Cast::make(body.dtype(), init_func(vars)); + BufHandle func_result = Buf::make(func_name, dims, body.dtype(), init_expr); + + ExprHandle reduce_op = reducer(func_result, body, output_args, reduce_vars); + if (body.dtype() == kBFloat16) { + ExprHandle init_expr_acc = Cast::make(kFloat, init_func(vars)); + BufHandle func_result_acc = + Buf::make(func_name + "_acc", dims, kFloat, init_expr_acc); + reduce_op = reducer( + func_result, + std::move(func_result_acc), + body, + output_args, + reduce_vars); + } + + Tensor t = Tensor( + std::move(func_result), + vars, + reduce_dims, + reduce_vars, + std::move(reduce_op)); + return t; +} +template +Tensor Reduce( + const std::string& func_name, + const std::vector& dims, + const Reducer& reducer, + const InitFunc& init_func, + const BodyFunc& body_func, + const std::vector& reduce_dims) { + return Reduce( + func_name, + dims, + std::nullopt, + reducer, + init_func, + body_func, + reduce_dims); +} + +template +Tensor Reduce( + const std::string& func_name, + const std::vector& dims, + const std::optional>& strides, + const Reducer& reducer, + const BodyFunc& body_func, + const std::vector& reduce_dims) { + return Reduce( + func_name, + dims, + strides, + reducer, + [&](ParameterList& p [[maybe_unused]]) { + return ExprHandle(reducer.initializer()); + }, + body_func, + reduce_dims); +} +template +Tensor Reduce( + const std::string& func_name, + const std::vector& dims, + const Reducer& reducer, + const BodyFunc& body_func, + const std::vector& reduce_dims) { + return Reduce( + func_name, dims, std::nullopt, reducer, body_func, reduce_dims); +} + +// Overload which allows inline lambda functions for the body_func. +template +Tensor Reduce( + const std::string& func_name, + const std::vector& dims, + const std::optional>& strides, + const Reducer& reducer, + const BodyFunc&& body_func, + const std::vector& reduce_dims) { + return Reduce(func_name, dims, strides, reducer, body_func, reduce_dims); +} +template +Tensor Reduce( + const std::string& func_name, + const std::vector& dims, + const Reducer& reducer, + const BodyFunc&& body_func, + const std::vector& reduce_dims) { + return Reduce(func_name, dims, std::nullopt, reducer, body_func, reduce_dims); +} + +TORCH_API Tensor Reduce( + const std::string& name, + const std::vector& dims, + const std::optional>& strides, + const Reducer& reducer, + const BufHandle& buffer, + const std::vector& reduce_dims); +TORCH_API Tensor Reduce( + const std::string& name, + const std::vector& dims, + const Reducer& reducer, + const BufHandle& buffer, + const std::vector& reduce_dims); + +// Overload for the common case of all dimensions of a previously Computed +// Tensor. +TORCH_API Tensor Reduce( + const std::string& func_name, + const std::vector& dims, + const std::optional>& strides, + const Reducer& reducer, + const Tensor& tensor, + const std::vector& reduce_dims); +TORCH_API Tensor Reduce( + const std::string& func_name, + const std::vector& dims, + const Reducer& reducer, + const Tensor& tensor, + const std::vector& reduce_dims); + +template +inline ExprHandle Tensor::load(const Ts&... ts) const { + std::vector params({ExprHandle(ts)...}); + return Load::make(BufHandle(this->buf()), params); +} + +template +inline ExprHandle Tensor::load(const std::vector& args) const { + std::vector params(args.begin(), args.end()); + return Load::make(BufHandle(this->buf()), params); +} + +template +inline ExprHandle BufHandle::load(const Ts&... ts) const { + std::vector params({ExprHandle(ts)...}); + return ExprHandle(alloc(node(), ExprHandleVectorToExprVector(params))); +} + +template +inline ExprHandle BufHandle::load(const std::vector& args) const { + std::vector params(args.begin(), args.end()); + return ExprHandle(alloc(node(), ExprHandleVectorToExprVector(params))); +} + +inline ExprHandle BufHandle::load(const std::vector& args) const { + return this->template load(args); +} + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/tensorexpr_init.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/tensorexpr_init.h new file mode 100644 index 0000000000000000000000000000000000000000..1dc38660615824654df2b5f5b0d6a426eefc263a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/tensorexpr_init.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include + +namespace torch::jit { +// Initialize Python bindings for Tensor Expressions +void initTensorExprBindings(PyObject* module); +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/types.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/types.h new file mode 100644 index 0000000000000000000000000000000000000000..cd23fdce4ae98e169607045dc217b308eea23f8a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/types.h @@ -0,0 +1,158 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include + +namespace torch::jit::tensorexpr { + +using int32 = std::int32_t; + +class Dtype; +TORCH_API std::ostream& operator<<(std::ostream& stream, const Dtype& dtype); + +using ScalarType = c10::ScalarType; + +enum ElementType { + kAllTypes = 0, + kIntegralTypes = 1 << 0, + kFloatingPointTypes = 1 << 1, + kBoolType = 1 << 2, + kComplexTypes = 1 << 3, + kQintTypes = 1 << 4, + kNonComplexOrQintTypes = kIntegralTypes | kBoolType | kFloatingPointTypes, +}; + +// Data types for scalar and vector elements. +class TORCH_API Dtype { + public: + explicit Dtype(int8_t type) + : scalar_type_(static_cast(type)), lanes_(1) {} + explicit Dtype(ScalarType type) : scalar_type_(type), lanes_(1) {} + Dtype(int8_t type, int64_t lanes) + : scalar_type_(static_cast(type)), lanes_(lanes) {} + Dtype(ScalarType type, int64_t lanes) : scalar_type_(type), lanes_(lanes) {} + Dtype(Dtype type, int64_t lanes) + : scalar_type_(type.scalar_type_), lanes_(lanes) { + if (type.lanes() != 1) { + throw malformed_input("dtype lanes dont match"); + } + } + int64_t lanes() const { + return lanes_; + } + ScalarType scalar_type() const { + return scalar_type_; + } + Dtype scalar_dtype() const; + bool operator==(const Dtype& other) const { + return scalar_type_ == other.scalar_type_ && lanes_ == other.lanes_; + } + bool operator!=(const Dtype& other) const { + return !(*this == other); + } + int byte_size() const; + std::string ToCppString() const; + + bool is_integral() const { + return c10::isIntegralType(scalar_type_, true); + } + bool is_floating_point() const { + return c10::isFloatingType(scalar_type_); + } + bool is_signed() const { + return c10::isSignedType(scalar_type_); + } + + Dtype cloneWithScalarType(ScalarType nt) const { + return Dtype(nt, lanes_); + } + + private: + friend TORCH_API std::ostream& operator<<( + std::ostream& stream, + const Dtype& dtype); + ScalarType scalar_type_; + int64_t lanes_; // the width of the element for a vector time +}; + +extern TORCH_API Dtype kHandle; + +#define NNC_DTYPE_DECLARATION(ctype, name) extern TORCH_API Dtype k##name; + +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, NNC_DTYPE_DECLARATION) +NNC_DTYPE_DECLARATION(c10::quint8, QUInt8) +NNC_DTYPE_DECLARATION(c10::qint8, QInt8) +#undef NNC_DTYPE_DECLARATION + +template +TORCH_API Dtype ToDtype(); + +#define NNC_TODTYPE_DECLARATION(ctype, name) \ + template <> \ + inline Dtype ToDtype() { \ + return k##name; \ + } +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, NNC_TODTYPE_DECLARATION) +NNC_TODTYPE_DECLARATION(c10::quint8, QUInt8) +NNC_TODTYPE_DECLARATION(c10::qint8, QInt8) +#undef NNC_TODTYPE_DECLARATION + +TORCH_API Dtype ToDtype(ScalarType type); + +inline Dtype promoteTypes(Dtype a, Dtype b) { + if (a.lanes() != b.lanes()) { + throw malformed_input("promoting types with different lanes"); + } + return Dtype( + static_cast(c10::promoteTypes( + static_cast(a.scalar_type()), + static_cast(b.scalar_type()))), + a.lanes()); +} + +inline Dtype BinaryOpDtype( + Dtype op1_dtype, + Dtype op2_dtype, + ScalarType ret_type = ScalarType::Undefined) { + if (op1_dtype == op2_dtype) { + if (ret_type == ScalarType::Undefined) { + return op1_dtype; + } + + return ToDtype(ret_type); + } + + if (op1_dtype.lanes() != op2_dtype.lanes()) { + throw malformed_input("lanes dont match"); + } + int64_t lanes = op1_dtype.lanes(); + + Dtype resultType = promoteTypes(op1_dtype, op2_dtype); + if (resultType.scalar_type() == ScalarType::Undefined) { + throw malformed_input("scalar type doesn't match"); + } + + if (lanes == 1) { + // Use the fixed scalar Dtypes. + return ToDtype(resultType.scalar_type()); + } + + return resultType; +} + +} // namespace torch::jit::tensorexpr + +namespace std { + +using torch::jit::tensorexpr::Dtype; +std::string to_string(const Dtype& dtype); +using torch::jit::tensorexpr::ScalarType; +std::string to_string(const ScalarType& dtype); + +} // namespace std diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/unique_name_manager.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/unique_name_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..f5ceac667d158730af57c4ddc6192e59040e3d4b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/unique_name_manager.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch::jit::tensorexpr { + +class VarHandle; +class Var; + +using VarNameMap = std::unordered_map; + +// A manager to get unique names from vars. +// It starts with the name hints of the var and append "_" + $counter until it +// hits a unique name. +class TORCH_API UniqueNameManager { + public: + const std::string& get_unique_name(const VarHandle& v); + + const std::string& get_unique_name(const VarPtr& v); + + private: + friend class ScopedVarName; + VarNameMap unique_name_mapping_; + std::unordered_map unique_name_count_; + std::unordered_set all_unique_names_; +}; + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/var_substitutor.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/var_substitutor.h new file mode 100644 index 0000000000000000000000000000000000000000..e3009902bc33497303942c86a8958dccde2eba1d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/tensorexpr/var_substitutor.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace torch::jit::tensorexpr { + +using VarMapping = std::vector>; + +class VarSubMutator : public IRMutator { + public: + VarSubMutator(const VarMapping& var_mapping) { + for (auto& entry : var_mapping) { + VarPtr key_var = entry.first; + ExprPtr value = entry.second; + if (!key_var) { + throw malformed_input("missing key in VarSubMutator"); + } + var_mapping_[std::move(key_var)] = std::move(value); + } + } + + ExprPtr mutate(const VarPtr& var) override { + auto iter = var_mapping_.find(var); + if (iter == var_mapping_.end()) { + return var; + } + return iter->second; + } + + ExprPtr mutate(const ReduceOpPtr& var) override { + auto body = var->body()->accept_mutator(this); + std::vector new_inner; + + for (const auto& v : var->reduce_args()) { + ExprPtr e = v->accept_mutator(this); + if (VarPtr new_var = to(e)) { + new_inner.push_back(std::move(new_var)); + } else { + VarFinder varFinder; + e->accept(&varFinder); + auto varlist = varFinder.vars(); + new_inner.insert(new_inner.end(), varlist.begin(), varlist.end()); + } + } + + return alloc(body, new_inner, var->reducer()); + } + + private: + std::unordered_map var_mapping_; +}; + +} // namespace torch::jit::tensorexpr diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/testing/catch_utils.hpp b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/testing/catch_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9e7696b137226304e40d754c3709fc14c586928b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/testing/catch_utils.hpp @@ -0,0 +1,10 @@ +#pragma once + +#define CATCH_CONFIG_PREFIX_ALL +#include + +// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes +// warning; define our own version that doesn't warn. +#define _CATCH_REQUIRE_THROWS(...) \ + INTERNAL_CATCH_THROWS( \ + "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__) diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/testing/file_check.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/testing/file_check.h new file mode 100644 index 0000000000000000000000000000000000000000..fd09fcc6ad30b476e373e2a7c96f5c266d9bd04f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/testing/file_check.h @@ -0,0 +1,79 @@ +#pragma once + +#include +#include +#include + +namespace torch::jit { + +struct Graph; + +namespace testing { + +struct FileCheckImpl; + +struct FileCheck { + public: + TORCH_API explicit FileCheck(); + TORCH_API ~FileCheck(); + + // Run FileCheck against test string + TORCH_API void run(const std::string& test_string); + + // Run FileCheck against dump of graph IR + TORCH_API void run(const Graph& graph); + + // Parsing input checks string and run against test string / dump of graph IR + TORCH_API void run( + const std::string& input_checks_string, + const std::string& test_string); + TORCH_API void run( + const std::string& input_checks_string, + const Graph& graph); + + // Checks that the string occurs, starting at the end of the most recent match + TORCH_API FileCheck* check(const std::string& str); + + // Checks that the string does not occur between the previous match and next + // match. Consecutive check_nots test against the same previous match and next + // match + TORCH_API FileCheck* check_not(const std::string& str); + + // Checks that the string occurs on the same line as the previous match + TORCH_API FileCheck* check_same(const std::string& str); + + // Checks that the string occurs on the line immediately following the + // previous match + TORCH_API FileCheck* check_next(const std::string& str); + + // Checks that the string occurs count number of times, starting at the end + // of the previous match. If exactly is true, checks that there are exactly + // count many matches + TORCH_API FileCheck* check_count( + const std::string& str, + size_t count, + bool exactly = false); + + // A series of consecutive check_dags get turned into a group of checks + // which can appear in any order relative to each other. The checks begin + // at the end of the previous match, and the match for the check_dag group + // is the minimum match of all individual checks to the maximum match of all + // individual checks. + TORCH_API FileCheck* check_dag(const std::string& str); + + // Checks that source token is highlighted in str (usually an error message). + TORCH_API FileCheck* check_source_highlighted(const std::string& str); + + // Checks that the regex matched string occurs, starting at the end of the + // most recent match + TORCH_API FileCheck* check_regex(const std::string& str); + + // reset checks + TORCH_API void reset(); + + private: + bool has_run = false; + std::unique_ptr fcImpl; +}; +} // namespace testing +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/testing/hooks_for_testing.h b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/testing/hooks_for_testing.h new file mode 100644 index 0000000000000000000000000000000000000000..5613a0d24476d6f0918dd5befa48e5c1785c0b3a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/csrc/jit/testing/hooks_for_testing.h @@ -0,0 +1,19 @@ +#pragma once +#include +#include +#include +#include + +namespace torch::jit { +struct Module; + +using ModuleHook = std::function; +using FunctionHook = std::function; + +TORCH_API void didFinishEmitModule(Module module); +TORCH_API void didFinishEmitFunction(StrongFunctionPtr defined); +TORCH_API void setEmitHooks(ModuleHook for_module, FunctionHook for_fn); + +TORCH_API std::pair getEmitHooks(); + +} // namespace torch::jit diff --git a/.venv/lib/python3.12/site-packages/torch/include/torch/headeronly/macros/Export.h b/.venv/lib/python3.12/site-packages/torch/include/torch/headeronly/macros/Export.h new file mode 100644 index 0000000000000000000000000000000000000000..183aeab563445c334af6921208147d315143ffbc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/include/torch/headeronly/macros/Export.h @@ -0,0 +1,87 @@ +#pragma once + +/* Header file to define the common scaffolding for exported symbols. + * + * Export is by itself a quite tricky situation to deal with, and if you are + * hitting this file, make sure you start with the background here: + * - Linux: https://gcc.gnu.org/wiki/Visibility + * - Windows: + * https://docs.microsoft.com/en-us/cpp/cpp/dllexport-dllimport?view=vs-2017 + * + * Do NOT include this file directly. Instead, use c10/macros/Macros.h + */ + +// You do not need to edit this part of file unless you are changing the core +// pytorch export abstractions. +// +// This part defines the C10 core export and import macros. This is controlled +// by whether we are building shared libraries or not, which is determined +// during build time and codified in c10/core/cmake_macros.h. +// When the library is built as a shared lib, EXPORT and IMPORT will contain +// visibility attributes. If it is being built as a static lib, then EXPORT +// and IMPORT basically have no effect. + +// As a rule of thumb, you should almost NEVER mix static and shared builds for +// libraries that depend on c10. AKA, if c10 is built as a static library, we +// recommend everything dependent on c10 to be built statically. If c10 is built +// as a shared library, everything dependent on it should be built as shared. In +// the PyTorch project, all native libraries shall use the macro +// C10_BUILD_SHARED_LIB to check whether pytorch is building shared or static +// libraries. + +// For build systems that do not directly depend on CMake and directly build +// from the source directory (such as Buck), one may not have a cmake_macros.h +// file at all. In this case, the build system is responsible for providing +// correct macro definitions corresponding to the cmake_macros.h.in file. +// +// In such scenarios, one should define the macro +// C10_USING_CUSTOM_GENERATED_MACROS +// to inform this header that it does not need to include the cmake_macros.h +// file. + +#ifdef _WIN32 +#define C10_HIDDEN +#if defined(C10_BUILD_SHARED_LIBS) +#define C10_EXPORT __declspec(dllexport) +#define C10_IMPORT __declspec(dllimport) +#else +#define C10_EXPORT +#define C10_IMPORT +#endif +#else // _WIN32 +#if defined(__GNUC__) +#define C10_EXPORT __attribute__((__visibility__("default"))) +#define C10_HIDDEN __attribute__((__visibility__("hidden"))) +#else // defined(__GNUC__) +#define C10_EXPORT +#define C10_HIDDEN +#endif // defined(__GNUC__) +#define C10_IMPORT C10_EXPORT +#endif // _WIN32 + +#ifdef NO_EXPORT +#undef C10_EXPORT +#define C10_EXPORT +#endif + +// Definition of an adaptive XX_API macro, that depends on whether you are +// building the library itself or not, routes to XX_EXPORT and XX_IMPORT. +// Basically, you will need to do this for each shared library that you are +// building, and the instruction is as follows: assuming that you are building +// a library called libawesome.so. You should: +// (1) for your cmake target (usually done by "add_library(awesome, ...)"), +// define a macro called AWESOME_BUILD_MAIN_LIB using +// target_compile_options. +// (2) define the AWESOME_API macro similar to the one below. +// And in the source file of your awesome library, use AWESOME_API to +// annotate public symbols. + +// Here, for the C10 library, we will define the macro C10_API for both import +// and export. + +// This one is being used by libc10.so +#ifdef C10_BUILD_MAIN_LIB +#define C10_API C10_EXPORT +#else +#define C10_API C10_IMPORT +#endif diff --git a/.venv/lib/python3.12/site-packages/torch/monitor/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/monitor/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21f716e6fc575f96357529472aa0879a62040fd6 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/monitor/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..373dbb29550d59fe769fae0b58df3f1c1c906905 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1f02557fc25f2d0a81ba1c01d6d8fb7a432c385 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fake_quantize.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e920184426770b6a15ccebf64db237bcf4c75b8f Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e9f3cd69955ec82e272bd94866e5dfcff5a7f12 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/fuser_method_mappings.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/observer.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/observer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dd048f265cd88d864dea39ab1030c03b7c82171 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/observer.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/qconfig.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/qconfig.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de972fd161421f4cb1f24d11ee5a65c811cdea9f Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/qconfig.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quant_type.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quant_type.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeb7efd82db47df0bad63fcd9673308fc1004a04 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quant_type.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bc5d1527f4245a75c5773e8c1d171a59212fb2e Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7307ca7e7b4ac4118fd6275d5db6ba3fc9fa72bb Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c781cac0a09f768b3b68c39326b794f9e0a6d1c Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/quantize_jit.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/stubs.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/stubs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d90a4fae9bf423044c19b5f0e5319735ad1ea21f Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/quantization/__pycache__/stubs.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/fx/__init__.py b/.venv/lib/python3.12/site-packages/torch/quantization/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c01cbd457374c27e40b07daca5ae1644a701767d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/quantization/fx/__init__.py @@ -0,0 +1,15 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.convert import convert +from torch.ao.quantization.fx.fuse import fuse + +# omitting files that's unlikely to be used right now, for example +# the newly added lower_to_fbgemm etc. +from torch.ao.quantization.fx.prepare import prepare diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/fx/_equalize.py b/.venv/lib/python3.12/site-packages/torch/quantization/fx/_equalize.py new file mode 100644 index 0000000000000000000000000000000000000000..7acea4f84a2a0a82f134b6790e573f8f1cb677f2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/quantization/fx/_equalize.py @@ -0,0 +1,38 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx._equalize import ( + _convert_equalization_ref, + _InputEqualizationObserver, + _WeightEqualizationObserver, + calculate_equalization_scale, + clear_weight_quant_obs_node, + convert_eq_obs, + CUSTOM_MODULE_SUPP_LIST, + custom_module_supports_equalization, + default_equalization_qconfig, + EqualizationQConfig, + fused_module_supports_equalization, + get_equalization_qconfig_dict, + get_layer_sqnr_dict, + get_op_node_and_weight_eq_obs, + input_equalization_observer, + is_equalization_observer, + maybe_get_next_equalization_scale, + maybe_get_next_input_eq_obs, + maybe_get_weight_eq_obs_node, + nn_module_supports_equalization, + node_supports_equalization, + remove_node, + reshape_scale, + scale_input_observer, + scale_weight_functional, + scale_weight_node, + update_obs_for_equalization, + weight_equalization_observer, +) diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/fx/convert.py b/.venv/lib/python3.12/site-packages/torch/quantization/fx/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..9d6ac350602bb7a97c773a3a09fec0780483379f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/quantization/fx/convert.py @@ -0,0 +1,9 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.convert import convert diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/fx/fuse.py b/.venv/lib/python3.12/site-packages/torch/quantization/fx/fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..67527080304fb31ddc54fe254533e2196f77a616 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/quantization/fx/fuse.py @@ -0,0 +1,9 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.fuse import fuse diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/fx/fusion_patterns.py b/.venv/lib/python3.12/site-packages/torch/quantization/fx/fusion_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..e29337b3f861e5b54dc9f37d39d12ad975ad1315 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/quantization/fx/fusion_patterns.py @@ -0,0 +1,9 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.fuse_handler import DefaultFuseHandler, FuseHandler diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/fx/graph_module.py b/.venv/lib/python3.12/site-packages/torch/quantization/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..a71e980a57ba141bdc5bbe9b283d69582eb8fd82 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/quantization/fx/graph_module.py @@ -0,0 +1,17 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.graph_module import ( + _is_observed_module, + _is_observed_standalone_module, + FusedGraphModule, + GraphModule, + ObservedGraphModule, + ObservedStandaloneGraphModule, + QuantizedGraphModule, +) diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/fx/match_utils.py b/.venv/lib/python3.12/site-packages/torch/quantization/fx/match_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8b49f7c645d8d1bc3a154d62a1295a90b155f986 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/quantization/fx/match_utils.py @@ -0,0 +1,14 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.match_utils import ( + _find_matches, + _is_match, + _MatchResult, + MatchAllNode, +) diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/fx/pattern_utils.py b/.venv/lib/python3.12/site-packages/torch/quantization/fx/pattern_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2a83e180fc4dbaa28d1d41a10037684f0afa6610 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/quantization/fx/pattern_utils.py @@ -0,0 +1,35 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.pattern_utils import ( + _register_fusion_pattern, + _register_quant_pattern, + get_default_fusion_patterns, + get_default_output_activation_post_process_map, + get_default_quant_patterns, + QuantizeHandler, +) + + +# QuantizeHandler.__module__ = _NAMESPACE +_register_fusion_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils" +get_default_fusion_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils" +_register_quant_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils" +get_default_quant_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils" +get_default_output_activation_post_process_map.__module__ = ( + "torch.ao.quantization.fx.pattern_utils" +) + +# __all__ = [ +# "QuantizeHandler", +# "_register_fusion_pattern", +# "get_default_fusion_patterns", +# "_register_quant_pattern", +# "get_default_quant_patterns", +# "get_default_output_activation_post_process_map", +# ] diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/fx/prepare.py b/.venv/lib/python3.12/site-packages/torch/quantization/fx/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..ca65dcc04dd0021f0065892ca86e209a1c218473 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/quantization/fx/prepare.py @@ -0,0 +1,9 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.prepare import prepare diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/fx/quantization_patterns.py b/.venv/lib/python3.12/site-packages/torch/quantization/fx/quantization_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..20d8cc52ee4fb16843becec5487d9d4ee46681c9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/quantization/fx/quantization_patterns.py @@ -0,0 +1,48 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.quantize_handler import ( + BatchNormQuantizeHandler, + BinaryOpQuantizeHandler, + CatQuantizeHandler, + ConvReluQuantizeHandler, + CopyNodeQuantizeHandler, + CustomModuleQuantizeHandler, + DefaultNodeQuantizeHandler, + EmbeddingQuantizeHandler, + FixedQParamsOpQuantizeHandler, + GeneralTensorShapeOpQuantizeHandler, + LinearReLUQuantizeHandler, + QuantizeHandler, + RNNDynamicQuantizeHandler, + StandaloneModuleQuantizeHandler, +) + + +QuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +BinaryOpQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +CatQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +ConvReluQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +LinearReLUQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +BatchNormQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +EmbeddingQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +RNNDynamicQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +DefaultNodeQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +FixedQParamsOpQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +CopyNodeQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +CustomModuleQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +GeneralTensorShapeOpQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +StandaloneModuleQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/fx/quantization_types.py b/.venv/lib/python3.12/site-packages/torch/quantization/fx/quantization_types.py new file mode 100644 index 0000000000000000000000000000000000000000..a422cdd3142e04c8d16f495cc6cd65823451810b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/quantization/fx/quantization_types.py @@ -0,0 +1,9 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.utils import Pattern, QuantizerCls diff --git a/.venv/lib/python3.12/site-packages/torch/quantization/fx/utils.py b/.venv/lib/python3.12/site-packages/torch/quantization/fx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ef35559884b7c430f1d5c72b21f72979108469a5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/quantization/fx/utils.py @@ -0,0 +1,20 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.utils import ( + all_node_args_have_no_tensors, + assert_and_get_unique_device, + create_getattr_from_value, + get_custom_module_class_keys, + get_linear_prepack_op_for_dtype, + get_new_attr_name_with_prefix, + get_non_observable_arg_indexes_and_types, + get_qconv_prepack_op, + graph_module_from_producer_nodes, + maybe_get_next_module, +) diff --git a/.venv/lib/python3.12/site-packages/torch/signal/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/signal/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b315d4254d0494d7dd8083030a3db6c8ec2656e3 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/signal/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/signal/windows/__init__.py b/.venv/lib/python3.12/site-packages/torch/signal/windows/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6749a92c6fc1525ea95c7d4d1e398229ab10b7a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/signal/windows/__init__.py @@ -0,0 +1,28 @@ +from .windows import ( + bartlett, + blackman, + cosine, + exponential, + gaussian, + general_cosine, + general_hamming, + hamming, + hann, + kaiser, + nuttall, +) + + +__all__ = [ + "bartlett", + "blackman", + "cosine", + "exponential", + "gaussian", + "general_cosine", + "general_hamming", + "hamming", + "hann", + "kaiser", + "nuttall", +] diff --git a/.venv/lib/python3.12/site-packages/torch/signal/windows/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/signal/windows/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca8ed2d4774f459ad8c2e2b6eadc6e2b59c14c26 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/signal/windows/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/signal/windows/__pycache__/windows.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/signal/windows/__pycache__/windows.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..357b94e515c99714357201709fedde67a2668d60 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/signal/windows/__pycache__/windows.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/signal/windows/windows.py b/.venv/lib/python3.12/site-packages/torch/signal/windows/windows.py new file mode 100644 index 0000000000000000000000000000000000000000..7d67de3f83848ec7d84eb8d3453db3e28d2413fa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/signal/windows/windows.py @@ -0,0 +1,890 @@ +# mypy: allow-untyped-defs +from collections.abc import Iterable +from math import sqrt +from typing import Callable, Optional, TypeVar + +import torch +from torch import Tensor +from torch._torch_docs import factory_common_args, merge_dicts, parse_kwargs + + +__all__ = [ + "bartlett", + "blackman", + "cosine", + "exponential", + "gaussian", + "general_cosine", + "general_hamming", + "hamming", + "hann", + "kaiser", + "nuttall", +] + +_T = TypeVar("_T") + +window_common_args = merge_dicts( + parse_kwargs( + """ + M (int): the length of the window. + In other words, the number of points of the returned window. + sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis. + If `True`, returns a symmetric window suitable for use in filter design. Default: `True`. +""" + ), + factory_common_args, + { + "normalization": "The window is normalized to 1 (maximum value is 1). However, the 1 doesn't appear if " + ":attr:`M` is even and :attr:`sym` is `True`.", + }, +) + + +def _add_docstr(*args: str) -> Callable[[_T], _T]: + r"""Adds docstrings to a given decorated function. + + Specially useful when then docstrings needs string interpolation, e.g., with + str.format(). + REMARK: Do not use this function if the docstring doesn't need string + interpolation, just write a conventional docstring. + + Args: + args (str): + """ + + def decorator(o: _T) -> _T: + o.__doc__ = "".join(args) + return o + + return decorator + + +def _window_function_checks( + function_name: str, M: int, dtype: torch.dtype, layout: torch.layout +) -> None: + r"""Performs common checks for all the defined windows. + This function should be called before computing any window. + + Args: + function_name (str): name of the window function. + M (int): length of the window. + dtype (:class:`torch.dtype`): the desired data type of returned tensor. + layout (:class:`torch.layout`): the desired layout of returned tensor. + """ + if M < 0: + raise ValueError( + f"{function_name} requires non-negative window length, got M={M}" + ) + if layout is not torch.strided: + raise ValueError( + f"{function_name} is implemented for strided tensors only, got: {layout}" + ) + if dtype not in [torch.float32, torch.float64]: + raise ValueError( + f"{function_name} expects float32 or float64 dtypes, got: {dtype}" + ) + + +@_add_docstr( + r""" +Computes a window with an exponential waveform. +Also known as Poisson window. + +The exponential window is defined as follows: + +.. math:: + w_n = \exp{\left(-\frac{|n - c|}{\tau}\right)} + +where `c` is the ``center`` of the window. + """, + r""" + +{normalization} + +Args: + {M} + +Keyword args: + center (float, optional): where the center of the window will be located. + Default: `M / 2` if `sym` is `False`, else `(M - 1) / 2`. + tau (float, optional): the decay value. + Tau is generally associated with a percentage, that means, that the value should + vary within the interval (0, 100]. If tau is 100, it is considered the uniform window. + Default: 1.0. + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric exponential window of size 10 and with a decay value of 1.0. + >>> # The center will be at (M - 1) / 2, where M is 10. + >>> torch.signal.windows.exponential(10) + tensor([0.0111, 0.0302, 0.0821, 0.2231, 0.6065, 0.6065, 0.2231, 0.0821, 0.0302, 0.0111]) + + >>> # Generates a periodic exponential window and decay factor equal to .5 + >>> torch.signal.windows.exponential(10, sym=False,tau=.5) + tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04]) + """.format( + **window_common_args + ), +) +def exponential( + M: int, + *, + center: Optional[float] = None, + tau: float = 1.0, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("exponential", M, dtype, layout) + + if tau <= 0: + raise ValueError(f"Tau must be positive, got: {tau} instead.") + + if sym and center is not None: + raise ValueError("Center must be None for symmetric windows") + + if M == 0: + return torch.empty( + (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + if center is None: + center = (M if not sym and M > 1 else M - 1) / 2.0 + + constant = 1 / tau + + k = torch.linspace( + start=-center * constant, + end=(-center + (M - 1)) * constant, + steps=M, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + return torch.exp(-torch.abs(k)) + + +@_add_docstr( + r""" +Computes a window with a simple cosine waveform, following the same implementation as SciPy. +This window is also known as the sine window. + +The cosine window is defined as follows: + +.. math:: + w_n = \sin\left(\frac{\pi (n + 0.5)}{M}\right) + +This formula differs from the typical cosine window formula by incorporating a 0.5 term in the numerator, +which shifts the sample positions. This adjustment results in a window that starts and ends with non-zero values. + +""", + r""" + +{normalization} + +Args: + {M} + +Keyword args: + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric cosine window. + >>> torch.signal.windows.cosine(10) + tensor([0.1564, 0.4540, 0.7071, 0.8910, 0.9877, 0.9877, 0.8910, 0.7071, 0.4540, 0.1564]) + + >>> # Generates a periodic cosine window. + >>> torch.signal.windows.cosine(10, sym=False) + tensor([0.1423, 0.4154, 0.6549, 0.8413, 0.9595, 1.0000, 0.9595, 0.8413, 0.6549, 0.4154]) +""".format( + **window_common_args, + ), +) +def cosine( + M: int, + *, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("cosine", M, dtype, layout) + + if M == 0: + return torch.empty( + (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + start = 0.5 + constant = torch.pi / (M + 1 if not sym and M > 1 else M) + + k = torch.linspace( + start=start * constant, + end=(start + (M - 1)) * constant, + steps=M, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + return torch.sin(k) + + +@_add_docstr( + r""" +Computes a window with a gaussian waveform. + +The gaussian window is defined as follows: + +.. math:: + w_n = \exp{\left(-\left(\frac{n}{2\sigma}\right)^2\right)} + """, + r""" + +{normalization} + +Args: + {M} + +Keyword args: + std (float, optional): the standard deviation of the gaussian. It controls how narrow or wide the window is. + Default: 1.0. + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric gaussian window with a standard deviation of 1.0. + >>> torch.signal.windows.gaussian(10) + tensor([4.0065e-05, 2.1875e-03, 4.3937e-02, 3.2465e-01, 8.8250e-01, 8.8250e-01, 3.2465e-01, 4.3937e-02, 2.1875e-03, 4.0065e-05]) + + >>> # Generates a periodic gaussian window and standard deviation equal to 0.9. + >>> torch.signal.windows.gaussian(10, sym=False,std=0.9) + tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00, 5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05]) +""".format( + **window_common_args, + ), +) +def gaussian( + M: int, + *, + std: float = 1.0, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("gaussian", M, dtype, layout) + + if std <= 0: + raise ValueError(f"Standard deviation must be positive, got: {std} instead.") + + if M == 0: + return torch.empty( + (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + start = -(M if not sym and M > 1 else M - 1) / 2.0 + + constant = 1 / (std * sqrt(2)) + + k = torch.linspace( + start=start * constant, + end=(start + (M - 1)) * constant, + steps=M, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + return torch.exp(-(k**2)) + + +@_add_docstr( + r""" +Computes the Kaiser window. + +The Kaiser window is defined as follows: + +.. math:: + w_n = I_0 \left( \beta \sqrt{1 - \left( {\frac{n - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta ) + +where ``I_0`` is the zeroth order modified Bessel function of the first kind (see :func:`torch.special.i0`), and +``N = M - 1 if sym else M``. + """, + r""" + +{normalization} + +Args: + {M} + +Keyword args: + beta (float, optional): shape parameter for the window. Must be non-negative. Default: 12.0 + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric gaussian window with a standard deviation of 1.0. + >>> torch.signal.windows.kaiser(5) + tensor([4.0065e-05, 2.1875e-03, 4.3937e-02, 3.2465e-01, 8.8250e-01, 8.8250e-01, 3.2465e-01, 4.3937e-02, 2.1875e-03, 4.0065e-05]) + >>> # Generates a periodic gaussian window and standard deviation equal to 0.9. + >>> torch.signal.windows.kaiser(5, sym=False,std=0.9) + tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00, 5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05]) +""".format( + **window_common_args, + ), +) +def kaiser( + M: int, + *, + beta: float = 12.0, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("kaiser", M, dtype, layout) + + if beta < 0: + raise ValueError(f"beta must be non-negative, got: {beta} instead.") + + if M == 0: + return torch.empty( + (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + if M == 1: + return torch.ones( + (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + # Avoid NaNs by casting `beta` to the appropriate dtype. + beta = torch.tensor(beta, dtype=dtype, device=device) + + start = -beta + constant = 2.0 * beta / (M if not sym else M - 1) + end = torch.minimum(beta, start + (M - 1) * constant) + + k = torch.linspace( + start=start, + end=end, + steps=M, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(beta) + + +@_add_docstr( + r""" +Computes the Hamming window. + +The Hamming window is defined as follows: + +.. math:: + w_n = \alpha - \beta\ \cos \left( \frac{2 \pi n}{M - 1} \right) + """, + r""" + +{normalization} + +Arguments: + {M} + +Keyword args: + {sym} + alpha (float, optional): The coefficient :math:`\alpha` in the equation above. + beta (float, optional): The coefficient :math:`\beta` in the equation above. + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric Hamming window. + >>> torch.signal.windows.hamming(10) + tensor([0.0800, 0.1876, 0.4601, 0.7700, 0.9723, 0.9723, 0.7700, 0.4601, 0.1876, 0.0800]) + + >>> # Generates a periodic Hamming window. + >>> torch.signal.windows.hamming(10, sym=False) + tensor([0.0800, 0.1679, 0.3979, 0.6821, 0.9121, 1.0000, 0.9121, 0.6821, 0.3979, 0.1679]) +""".format( + **window_common_args + ), +) +def hamming( + M: int, + *, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + return general_hamming( + M, + sym=sym, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + +@_add_docstr( + r""" +Computes the Hann window. + +The Hann window is defined as follows: + +.. math:: + w_n = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{M - 1} \right)\right] = + \sin^2 \left( \frac{\pi n}{M - 1} \right) + """, + r""" + +{normalization} + +Arguments: + {M} + +Keyword args: + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric Hann window. + >>> torch.signal.windows.hann(10) + tensor([0.0000, 0.1170, 0.4132, 0.7500, 0.9698, 0.9698, 0.7500, 0.4132, 0.1170, 0.0000]) + + >>> # Generates a periodic Hann window. + >>> torch.signal.windows.hann(10, sym=False) + tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) +""".format( + **window_common_args + ), +) +def hann( + M: int, + *, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + return general_hamming( + M, + alpha=0.5, + sym=sym, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + +@_add_docstr( + r""" +Computes the Blackman window. + +The Blackman window is defined as follows: + +.. math:: + w_n = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{M - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{M - 1} \right) + """, + r""" + +{normalization} + +Arguments: + {M} + +Keyword args: + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric Blackman window. + >>> torch.signal.windows.blackman(5) + tensor([-1.4901e-08, 3.4000e-01, 1.0000e+00, 3.4000e-01, -1.4901e-08]) + + >>> # Generates a periodic Blackman window. + >>> torch.signal.windows.blackman(5, sym=False) + tensor([-1.4901e-08, 2.0077e-01, 8.4923e-01, 8.4923e-01, 2.0077e-01]) +""".format( + **window_common_args + ), +) +def blackman( + M: int, + *, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("blackman", M, dtype, layout) + + return general_cosine( + M, + a=[0.42, 0.5, 0.08], + sym=sym, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + +@_add_docstr( + r""" +Computes the Bartlett window. + +The Bartlett window is defined as follows: + +.. math:: + w_n = 1 - \left| \frac{2n}{M - 1} - 1 \right| = \begin{cases} + \frac{2n}{M - 1} & \text{if } 0 \leq n \leq \frac{M - 1}{2} \\ + 2 - \frac{2n}{M - 1} & \text{if } \frac{M - 1}{2} < n < M \\ \end{cases} + """, + r""" + +{normalization} + +Arguments: + {M} + +Keyword args: + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric Bartlett window. + >>> torch.signal.windows.bartlett(10) + tensor([0.0000, 0.2222, 0.4444, 0.6667, 0.8889, 0.8889, 0.6667, 0.4444, 0.2222, 0.0000]) + + >>> # Generates a periodic Bartlett window. + >>> torch.signal.windows.bartlett(10, sym=False) + tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 0.8000, 0.6000, 0.4000, 0.2000]) +""".format( + **window_common_args + ), +) +def bartlett( + M: int, + *, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("bartlett", M, dtype, layout) + + if M == 0: + return torch.empty( + (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + if M == 1: + return torch.ones( + (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + start = -1 + constant = 2 / (M if not sym else M - 1) + + k = torch.linspace( + start=start, + end=start + (M - 1) * constant, + steps=M, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + return 1 - torch.abs(k) + + +@_add_docstr( + r""" +Computes the general cosine window. + +The general cosine window is defined as follows: + +.. math:: + w_n = \sum^{M-1}_{i=0} (-1)^i a_i \cos{ \left( \frac{2 \pi i n}{M - 1}\right)} + """, + r""" + +{normalization} + +Arguments: + {M} + +Keyword args: + a (Iterable): the coefficients associated to each of the cosine functions. + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric general cosine window with 3 coefficients. + >>> torch.signal.windows.general_cosine(10, a=[0.46, 0.23, 0.31], sym=True) + tensor([0.5400, 0.3376, 0.1288, 0.4200, 0.9136, 0.9136, 0.4200, 0.1288, 0.3376, 0.5400]) + + >>> # Generates a periodic general cosine window with 2 coefficients. + >>> torch.signal.windows.general_cosine(10, a=[0.5, 1 - 0.5], sym=False) + tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) +""".format( + **window_common_args + ), +) +def general_cosine( + M, + *, + a: Iterable, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + if dtype is None: + dtype = torch.get_default_dtype() + + _window_function_checks("general_cosine", M, dtype, layout) + + if M == 0: + return torch.empty( + (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + if M == 1: + return torch.ones( + (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + if not isinstance(a, Iterable): + raise TypeError("Coefficients must be a list/tuple") + + if not a: + raise ValueError("Coefficients cannot be empty") + + constant = 2 * torch.pi / (M if not sym else M - 1) + + k = torch.linspace( + start=0, + end=(M - 1) * constant, + steps=M, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + a_i = torch.tensor( + [(-1) ** i * w for i, w in enumerate(a)], + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + i = torch.arange( + a_i.shape[0], + dtype=a_i.dtype, + device=a_i.device, + requires_grad=a_i.requires_grad, + ) + return (a_i.unsqueeze(-1) * torch.cos(i.unsqueeze(-1) * k)).sum(0) + + +@_add_docstr( + r""" +Computes the general Hamming window. + +The general Hamming window is defined as follows: + +.. math:: + w_n = \alpha - (1 - \alpha) \cos{ \left( \frac{2 \pi n}{M-1} \right)} + """, + r""" + +{normalization} + +Arguments: + {M} + +Keyword args: + alpha (float, optional): the window coefficient. Default: 0.54. + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +Examples:: + + >>> # Generates a symmetric Hamming window with the general Hamming window. + >>> torch.signal.windows.general_hamming(10, sym=True) + tensor([0.0800, 0.1876, 0.4601, 0.7700, 0.9723, 0.9723, 0.7700, 0.4601, 0.1876, 0.0800]) + + >>> # Generates a periodic Hann window with the general Hamming window. + >>> torch.signal.windows.general_hamming(10, alpha=0.5, sym=False) + tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) +""".format( + **window_common_args + ), +) +def general_hamming( + M, + *, + alpha: float = 0.54, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + return general_cosine( + M, + a=[alpha, 1.0 - alpha], + sym=sym, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) + + +@_add_docstr( + r""" +Computes the minimum 4-term Blackman-Harris window according to Nuttall. + +.. math:: + w_n = 1 - 0.36358 \cos{(z_n)} + 0.48917 \cos{(2z_n)} - 0.13659 \cos{(3z_n)} + 0.01064 \cos{(4z_n)} + +where :math:`z_n = \frac{2 \pi n}{M}`. + """, + """ + +{normalization} + +Arguments: + {M} + +Keyword args: + {sym} + {dtype} + {layout} + {device} + {requires_grad} + +References:: + + - A. Nuttall, "Some windows with very good sidelobe behavior," + IEEE Transactions on Acoustics, Speech, and Signal Processing, vol. 29, no. 1, pp. 84-91, + Feb 1981. https://doi.org/10.1109/TASSP.1981.1163506 + + - Heinzel G. et al., "Spectrum and spectral density estimation by the Discrete Fourier transform (DFT), + including a comprehensive list of window functions and some new flat-top windows", + February 15, 2002 https://holometer.fnal.gov/GH_FFT.pdf + +Examples:: + + >>> # Generates a symmetric Nutall window. + >>> torch.signal.windows.general_hamming(5, sym=True) + tensor([3.6280e-04, 2.2698e-01, 1.0000e+00, 2.2698e-01, 3.6280e-04]) + + >>> # Generates a periodic Nuttall window. + >>> torch.signal.windows.general_hamming(5, sym=False) + tensor([3.6280e-04, 1.1052e-01, 7.9826e-01, 7.9826e-01, 1.1052e-01]) +""".format( + **window_common_args + ), +) +def nuttall( + M: int, + *, + sym: bool = True, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + requires_grad: bool = False, +) -> Tensor: + return general_cosine( + M, + a=[0.3635819, 0.4891775, 0.1365995, 0.0106411], + sym=sym, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) diff --git a/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/__init__.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6e8403ab54ef91822381c46cec5f1541d4a4102 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/__init__.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_conversions.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_conversions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e906a5e488a3e9315ec508e1f409d7fdbec5ee3 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_conversions.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_ops.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d2b7b5c2d919797a099342f4913d5462fb763c6 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/_semi_structured_ops.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/semi_structured.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/semi_structured.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..666b1b799c68455327fb113e5e99361a57d980b3 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/sparse/__pycache__/semi_structured.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_import_utils.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_import_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0a357553c879a045df97269f1f9873de4453cb5 Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/_import_utils.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/model_zoo.cpython-312.pyc b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/model_zoo.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c61009217922e5446574c7f7d08399456e078ea Binary files /dev/null and b/.venv/lib/python3.12/site-packages/torch/utils/__pycache__/model_zoo.cpython-312.pyc differ diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/cli_function_profiler.py b/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/cli_function_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..d2c1ee83bc1c3f58b00d85a511a8e8950df9f4d1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/_strobelight/cli_function_profiler.py @@ -0,0 +1,312 @@ +# mypy: disallow-untyped-defs + +import functools +import logging +import os +import re +import subprocess +import time +from collections.abc import Sequence +from threading import Lock +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import ParamSpec + + +logger = logging.getLogger("strobelight_function_profiler") + +console_handler = logging.StreamHandler() +formatter = logging.Formatter( + "%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s" +) +console_handler.setFormatter(formatter) + +logger.addHandler(console_handler) +logger.setLevel(logging.INFO) +logger.propagate = False + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +class StrobelightCLIProfilerError(Exception): + """ + Raised when an error happens during strobelight profiling + """ + + +def _pid_namespace_link(pid: Optional[int] = None) -> str: + """Returns the link to the process's namespace, example: pid:[4026531836]""" + PID_NAMESPACE_PATH = "/proc/{}/ns/pid" + pid = pid or os.getpid() + return os.readlink(PID_NAMESPACE_PATH.format(pid)) + + +def _pid_namespace(pid: Optional[int] = None) -> int: + """Returns the process's namespace id""" + pid = pid or os.getpid() + link = _pid_namespace_link(pid) + return int(link[link.find("[") + 1 : -1]) + + +def _command_to_string(command: Sequence[str]) -> str: + return " ".join(command) + + +class StrobelightCLIFunctionProfiler: + """ + Note: this is a meta only tool. + + StrobelightCLIFunctionProfiler can be used to profile a python function and + generate a strobelight link with the results. It works on meta servers but + does not requries an fbcode target. + When stop_at_error is false(default), error during profiling does not prevent + the work function from running. + + Check function_profiler_example.py for an example. + """ + + # This lock is used to make sure only one thread is running the profiler at any point. + _lock = Lock() + + def __init__( + self, + *, + stop_at_error: bool = False, + max_profile_duration_sec: int = 60 * 10, + sample_each: float = 1e7, # sample each sample_each cycles. + run_user_name: str = "pytorch-strobelight-ondemand", + timeout_wait_for_running_sec: int = 60, + timeout_wait_for_finished_sec: int = 60, + recorded_env_variables: Optional[list[str]] = None, + sample_tags: Optional[list[str]] = None, + stack_max_len: int = 127, + async_stack_max_len: int = 127, + ): + self.stop_at_error = stop_at_error + self.max_profile_duration_sec = max_profile_duration_sec + self.sample_each = sample_each + self.run_user_name = run_user_name + self.timeout_wait_for_running_sec = timeout_wait_for_running_sec + self.timeout_wait_for_finished_sec = timeout_wait_for_finished_sec + # Results of the most recent run. + # Tracks the strobelight run id of the most recent run + self.current_run_id: Optional[int] = None + self.sample_tags = sample_tags + + def _run_async(self) -> None: + processId = os.getpid() + namespace = _pid_namespace(processId) + command = [ + "strobeclient", + "run", + "--profiler", + "pyperf", + "--event", + "cycles", + "--async", + "--sample-interval", + f"{int(self.sample_each)}", + "--duration-ms", + f"{int(self.max_profile_duration_sec * 1000)}", + "--pid", + f"{namespace}:{processId}", + ] + + if self.sample_tags: + command.append("--sample-tags") + command.append(",".join(self.sample_tags)) + + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, error in run_async:{output}" + ) + + if match := re.search(r"INFO Run Id: (-?\d+)", output): + self.current_run_id = int(match.group(1)) + return + + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, unexpected result {output}" + ) + + def _wait_for_running(self, counter: int = 0) -> None: + if counter > 20: + raise StrobelightCLIProfilerError( + "wait_for_running called more than 20 times" + ) + + command = ["strobeclient", "getRunStatus", "--run-id", f"{self.current_run_id}"] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, error in wait_for_running:{output}" + ) + + if match := re.search("Profile run status: (.*)", output): + current_status = match.group(1) + if current_status == "RUNNING": + return + elif current_status == "PREPARING": + time.sleep(10) + self._wait_for_running(counter + 1) + return + else: + raise StrobelightCLIProfilerError(f"unexpected {current_status} phase") + + raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ") + + def _stop_run(self) -> None: + command = ["strobeclient", "stopRun", "--run-id", str(self.current_run_id)] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to stop strobelight profiling, return code is not 0 :{output}" + ) + + if match := re.search("INFO ::1:(.*)", output): + current_status = match.group(1) + if current_status.__contains__("Success!"): + return + else: + raise StrobelightCLIProfilerError( + f"failed to stop strobelight profiling, got {current_status} result" + ) + + raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ") + + def _get_results(self) -> None: + command = ["strobeclient", "getRunStatus", "--run-id", str(self.current_run_id)] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to extract profiling results, return code is not 0 : {output}" + ) + + if match := re.search("INFO ::1:(.*)", output): + current_status = match.group(1) + if current_status.__contains__("Profile run status: PROCESSING"): + time.sleep(10) + self._get_results() + return + elif not current_status.__contains__("Profile run finished with SUCCESS"): + raise StrobelightCLIProfilerError( + f"failed to extract profiling results, unexpected response {output}" + ) + + for item in re.findall( + r"(Total samples(.*)|GraphProfiler(.*)|Icicle view \(python stack\)(.*))", + output, + ): + logger.info(item[0]) + + def _stop_strobelight_no_throw( + self, + collect_results: bool, + ) -> None: + try: + # call stop run + self._stop_run() + logger.info("strobelight profiling stopped") + + logger.debug("collection stopped") + + if not collect_results: + return + + self._get_results() + except Exception: + logger.warning("error during stop_strobelight", exc_info=True) + + # Return true if strobelight started and is running. Never throw. + def _start_strobelight(self) -> bool: + strobelight_started = False + try: + self._run_async() + strobelight_started = True + logger.info("strobelight run id is: %s", self.current_run_id) + self._wait_for_running() + logger.info("strobelight profiling running") + return True + + except Exception: + logger.warning("error during start_strobelight:", exc_info=True) + if strobelight_started: + self._stop_strobelight_no_throw(collect_results=False) + return False + + def profile( + self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs + ) -> Optional[_R]: + self.current_run_id = None + + if locked := StrobelightCLIFunctionProfiler._lock.acquire(False): + if not locked: + if self.stop_at_error: + raise StrobelightCLIProfilerError("concurrent runs not supported") + + logger.warning("concurrent runs not supported") + return work_function(*args, **kwargs) + + started = self._start_strobelight() + if not started: + if self.stop_at_error: + StrobelightCLIFunctionProfiler._lock.release() + raise StrobelightCLIProfilerError( + "failed to start strobelight profiling" + ) + result = work_function(*args, **kwargs) + StrobelightCLIFunctionProfiler._lock.release() + return result + + try: + logger.debug("collection started") + result = work_function(*args, **kwargs) + self._stop_strobelight_no_throw(collect_results=True) + StrobelightCLIFunctionProfiler._lock.release() + return result + except Exception as error: + logger.warning("work function throw exception", exc_info=True) + self._stop_strobelight_no_throw(collect_results=False) + StrobelightCLIFunctionProfiler._lock.release() + raise error + return None + + +# A function decorator that wraps profile, if no profiler is provided one with +# default args is created. A function can be annotated as: +# @strobelight() +# @strobelight(profiler = StrobelightFunctionProfiler(stop_at_error=True,..)) +# @strobelight(stop_at_error=True,...) +def strobelight( + profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any +) -> Callable[[Callable[_P, _R]], Callable[_P, Optional[_R]]]: + if not profiler: + profiler = StrobelightCLIFunctionProfiler(**kwargs) + + def strobelight_inner( + work_function: Callable[_P, _R] + ) -> Callable[_P, Optional[_R]]: + @functools.wraps(work_function) + def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: + return profiler.profile(work_function, *args, **kwargs) + + return wrapper_function + + return strobelight_inner diff --git a/.venv/lib/python3.12/site-packages/torch/utils/benchmark/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e814aaf4671ca35484c43bc38677849d02a81ec --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/benchmark/__init__.py @@ -0,0 +1,6 @@ +from torch.utils.benchmark.utils.common import * # noqa: F403 +from torch.utils.benchmark.utils.timer import * # noqa: F403 +from torch.utils.benchmark.utils.compare import * # noqa: F403 +from torch.utils.benchmark.utils.fuzzer import * # noqa: F403 +from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import * # noqa: F403 +from torch.utils.benchmark.utils.sparse_fuzzer import * # noqa: F403 diff --git a/.venv/lib/python3.12/site-packages/torch/utils/bottleneck/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/bottleneck/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/utils/bottleneck/__main__.py b/.venv/lib/python3.12/site-packages/torch/utils/bottleneck/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..d8bc43be0e2bbb7aed97cda2c10e45895d6071b9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/bottleneck/__main__.py @@ -0,0 +1,229 @@ +# mypy: allow-untyped-defs +import argparse +import cProfile +import pstats +import sys +import os + +import torch +from torch.autograd import profiler +from torch.utils.collect_env import get_env_info + + +def redirect_argv(new_argv): + sys.argv[:] = new_argv[:] + + +def compiled_with_cuda(sysinfo): + if sysinfo.cuda_compiled_version: + return f'compiled w/ CUDA {sysinfo.cuda_compiled_version}' + return 'not compiled w/ CUDA' + + +env_summary = """ +-------------------------------------------------------------------------------- + Environment Summary +-------------------------------------------------------------------------------- +PyTorch {pytorch_version}{debug_str} {cuda_compiled} +Running with Python {py_version} and {cuda_runtime} + +`{pip_version} list` truncated output: +{pip_list_output} +""".strip() + + +def run_env_analysis(): + print('Running environment analysis...') + info = get_env_info() + + result: dict[str, str] = {} + + debug_str = '' + if info.is_debug_build: + debug_str = ' DEBUG' + + cuda_avail = '' + if info.is_cuda_available: + cuda = info.cuda_runtime_version + if cuda is not None: + cuda_avail = 'CUDA ' + cuda + else: + cuda = 'CUDA unavailable' + + pip_version = info.pip_version + pip_list_output = info.pip_packages + if pip_list_output is None: + pip_list_output = 'Unable to fetch' + + result = { + 'debug_str': debug_str, + 'pytorch_version': info.torch_version, + 'cuda_compiled': compiled_with_cuda(info), + 'py_version': f'{sys.version_info[0]}.{sys.version_info[1]}', + 'cuda_runtime': cuda_avail, + 'pip_version': pip_version, + 'pip_list_output': pip_list_output, + } + + return env_summary.format(**result) + + +def run_cprofile(code, globs, launch_blocking=False): + print('Running your script with cProfile') + prof = cProfile.Profile() + prof.enable() + exec(code, globs, None) + prof.disable() + return prof + + +cprof_summary = """ +-------------------------------------------------------------------------------- + cProfile output +-------------------------------------------------------------------------------- +""".strip() + + +def print_cprofile_summary(prof, sortby='tottime', topk=15): + print(cprof_summary) + cprofile_stats = pstats.Stats(prof).sort_stats(sortby) + cprofile_stats.print_stats(topk) + + +def run_autograd_prof(code, globs): + def run_prof(use_cuda=False): + with profiler.profile(use_cuda=use_cuda) as prof: + exec(code, globs, None) + return prof + + print('Running your script with the autograd profiler...') + result = [run_prof(use_cuda=False)] + if torch.cuda.is_available(): + result.append(run_prof(use_cuda=True)) + else: + result.append(None) + + return result + + +autograd_prof_summary = """ +-------------------------------------------------------------------------------- + autograd profiler output ({mode} mode) +-------------------------------------------------------------------------------- + {description} +{cuda_warning} +{output} +""".strip() + + +def print_autograd_prof_summary(prof, mode, sortby='cpu_time', topk=15): + valid_sortby = ['cpu_time', 'cuda_time', 'cpu_time_total', 'cuda_time_total', 'count'] + if sortby not in valid_sortby: + warn = ('WARNING: invalid sorting option for autograd profiler results: {}\n' + 'Expected `cpu_time`, `cpu_time_total`, or `count`. ' + 'Defaulting to `cpu_time`.') + print(warn.format(sortby)) + sortby = 'cpu_time' + + if mode == 'CUDA': + cuda_warning = ('\n\tBecause the autograd profiler uses the CUDA event API,\n' + '\tthe CUDA time column reports approximately max(cuda_time, cpu_time).\n' + '\tPlease ignore this output if your code does not use CUDA.\n') + else: + cuda_warning = '' + + sorted_events = sorted(prof.function_events, + key=lambda x: getattr(x, sortby), reverse=True) + topk_events = sorted_events[:topk] + + result = { + 'mode': mode, + 'description': f'top {topk} events sorted by {sortby}', + 'output': torch.autograd.profiler_util._build_table(topk_events), + 'cuda_warning': cuda_warning + } + + print(autograd_prof_summary.format(**result)) + + +descript = """ +`bottleneck` is a tool that can be used as an initial step for debugging +bottlenecks in your program. + +It summarizes runs of your script with the Python profiler and PyTorch\'s +autograd profiler. Because your script will be profiled, please ensure that it +exits in a finite amount of time. + +For more complicated uses of the profilers, please see +https://docs.python.org/3/library/profile.html and +https://pytorch.org/docs/main/autograd.html#profiler for more information. +""".strip() + + +def parse_args(): + parser = argparse.ArgumentParser(description=descript) + parser.add_argument('scriptfile', type=str, + help='Path to the script to be run. ' + 'Usually run with `python path/to/script`.') + parser.add_argument('args', type=str, nargs=argparse.REMAINDER, + help='Command-line arguments to be passed to the script.') + return parser.parse_args() + + +def cpu_time_total(autograd_prof): + return sum(event.cpu_time_total for event in autograd_prof.function_events) + + +def main(): + args = parse_args() + + # Customizable constants. + scriptfile = args.scriptfile + scriptargs = [] if args.args is None else args.args + scriptargs.insert(0, scriptfile) + cprofile_sortby = 'tottime' + cprofile_topk = 15 + autograd_prof_sortby = 'cpu_time_total' + autograd_prof_topk = 15 + + redirect_argv(scriptargs) + + sys.path.insert(0, os.path.dirname(scriptfile)) + with open(scriptfile, 'rb') as stream: + code = compile(stream.read(), scriptfile, 'exec') + globs = { + '__file__': scriptfile, + '__name__': '__main__', + '__package__': None, + '__cached__': None, + } + + print(descript) + + env_summary = run_env_analysis() + + if torch.cuda.is_available(): + torch.cuda.init() + cprofile_prof = run_cprofile(code, globs) + autograd_prof_cpu, autograd_prof_cuda = run_autograd_prof(code, globs) + + print(env_summary) + print_cprofile_summary(cprofile_prof, cprofile_sortby, cprofile_topk) + + if not torch.cuda.is_available(): + print_autograd_prof_summary(autograd_prof_cpu, 'CPU', autograd_prof_sortby, autograd_prof_topk) + return + + # Print both the result of the CPU-mode and CUDA-mode autograd profilers + # if their execution times are very different. + cuda_prof_exec_time = cpu_time_total(autograd_prof_cuda) + if len(autograd_prof_cpu.function_events) > 0: + cpu_prof_exec_time = cpu_time_total(autograd_prof_cpu) + pct_diff = (cuda_prof_exec_time - cpu_prof_exec_time) / cuda_prof_exec_time + if abs(pct_diff) > 0.05: + print_autograd_prof_summary(autograd_prof_cpu, 'CPU', autograd_prof_sortby, autograd_prof_topk) + + print_autograd_prof_summary(autograd_prof_cuda, 'CUDA', autograd_prof_sortby, autograd_prof_topk) + +if __name__ == '__main__': + main() diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4feeda1e59fb9a5089f7df871d1c8b29a2cd3835 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/data/__init__.py @@ -0,0 +1,77 @@ +from torch.utils.data.dataloader import ( + _DatasetKind, + DataLoader, + default_collate, + default_convert, + get_worker_info, +) +from torch.utils.data.datapipes._decorator import ( + argument_validation, + functional_datapipe, + guaranteed_datapipes_determinism, + non_deterministic, + runtime_validation, + runtime_validation_disabled, +) +from torch.utils.data.datapipes.datapipe import ( + DataChunk, + DFIterDataPipe, + IterDataPipe, + MapDataPipe, +) +from torch.utils.data.dataset import ( + ChainDataset, + ConcatDataset, + Dataset, + IterableDataset, + random_split, + StackDataset, + Subset, + TensorDataset, +) +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import ( + BatchSampler, + RandomSampler, + Sampler, + SequentialSampler, + SubsetRandomSampler, + WeightedRandomSampler, +) + + +__all__ = [ + "BatchSampler", + "ChainDataset", + "ConcatDataset", + "DFIterDataPipe", + "DataChunk", + "DataLoader", + "Dataset", + "DistributedSampler", + "IterDataPipe", + "IterableDataset", + "MapDataPipe", + "RandomSampler", + "Sampler", + "SequentialSampler", + "StackDataset", + "Subset", + "SubsetRandomSampler", + "TensorDataset", + "WeightedRandomSampler", + "_DatasetKind", + "argument_validation", + "default_collate", + "default_convert", + "functional_datapipe", + "get_worker_info", + "guaranteed_datapipes_determinism", + "non_deterministic", + "random_split", + "runtime_validation", + "runtime_validation_disabled", +] + +# Please keep this list sorted +assert __all__ == sorted(__all__) diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/backward_compatibility.py b/.venv/lib/python3.12/site-packages/torch/utils/data/backward_compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f1c4e30ef720f676cf6581333cf3d48733e640 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/data/backward_compatibility.py @@ -0,0 +1,11 @@ +# mypy: allow-untyped-defs +from typing_extensions import deprecated as _deprecated + + +@_deprecated( + "Usage of `backward_compatibility.worker_init_fn` is deprecated " + "as `DataLoader` automatically applies sharding in every worker", + category=FutureWarning, +) +def worker_init_fn(worker_id): + pass diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py b/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..fc230a01a75c62bab6702f01b2e435fcd2c69a32 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py @@ -0,0 +1,1664 @@ +# mypy: allow-untyped-defs +r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter. + +To support these two classes, in `./_utils` we define many utility methods and +functions to be run in multiprocessing. E.g., the data loading worker loop is +in `./_utils/worker.py`. +""" +from __future__ import annotations + +import functools +import itertools +import logging +import multiprocessing as python_multiprocessing +import os +import queue +import threading +import warnings +from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import Self + +import torch +import torch.distributed as dist +import torch.utils.data.graph_settings +from torch._utils import ExceptionWrapper +from torch.utils.data import _utils +from torch.utils.data.datapipes.datapipe import ( + _IterDataPipeSerializationWrapper, + _MapDataPipeSerializationWrapper, + IterDataPipe, + MapDataPipe, +) +from torch.utils.data.dataset import Dataset, IterableDataset +from torch.utils.data.sampler import ( + BatchSampler, + RandomSampler, + Sampler, + SequentialSampler, +) + + +if TYPE_CHECKING: + from collections.abc import Iterable + +__all__ = [ + "DataLoader", + "get_worker_info", + "default_collate", + "default_convert", +] + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +_worker_init_fn_t = Callable[[int], None] + +# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that +# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'. +# See https://github.com/python/mypy/issues/3737. +_collate_fn_t = Callable[[list[_T]], Any] + + +# These functions used to be defined in this file. However, it was moved to +# _utils/collate.py. Although it is rather hard to access this from user land +# (one has to explicitly directly `import torch.utils.data.dataloader`), there +# probably is user code out there using it. This aliasing maintains BC in this +# aspect. +default_collate: _collate_fn_t = _utils.collate.default_collate +default_convert = _utils.collate.default_convert + +get_worker_info = _utils.worker.get_worker_info + +logger = logging.getLogger(__name__) + + +class _DatasetKind: + Map = 0 + Iterable = 1 + + @staticmethod + def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): + if kind == _DatasetKind.Map: + return _utils.fetch._MapDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + else: + return _utils.fetch._IterableDatasetFetcher( + dataset, auto_collation, collate_fn, drop_last + ) + + +class _InfiniteConstantSampler(Sampler): + r"""Analogous to ``itertools.repeat(None, None)``. + + Used as sampler for :class:`~torch.utils.data.IterableDataset`. + """ + + def __iter__(self): + while True: + yield None + + +def _get_distributed_settings(): + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size(), dist.get_rank() + else: + return 1, 0 + + +def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id): + global_worker_id = worker_id + info = torch.utils.data.get_worker_info() + assert info is not None + total_workers = info.num_workers + datapipe = info.dataset + assert isinstance(datapipe, (IterDataPipe, MapDataPipe)) + # To distribute elements across distributed process evenly, we should shard data on distributed + # processes first then shard on worker processes + total_workers *= world_size + global_worker_id = global_worker_id * world_size + rank_id + # For BC, use default SHARDING_PRIORITIES + torch.utils.data.graph_settings.apply_sharding( + datapipe, total_workers, global_worker_id + ) + if worker_init_fn is not None: + worker_init_fn(worker_id) + + +def _share_dist_seed(generator, pg): + _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator) + if isinstance(pg, dist.ProcessGroup): + dist.broadcast(_shared_seed, src=0, group=pg) + return _shared_seed.item() + + +class DataLoader(Generic[_T_co]): + r""" + Data loader combines a dataset and a sampler, and provides an iterable over the given dataset. + + The :class:`~torch.utils.data.DataLoader` supports both map-style and + iterable-style datasets with single- or multi-process loading, customizing + loading order and optional automatic batching (collation) and memory pinning. + + See :py:mod:`torch.utils.data` documentation page for more details. + + Args: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: ``1``). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: ``False``). + sampler (Sampler or Iterable, optional): defines the strategy to draw + samples from the dataset. Can be any ``Iterable`` with ``__len__`` + implemented. If specified, :attr:`shuffle` must not be specified. + batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but + returns a batch of indices at a time. Mutually exclusive with + :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, + and :attr:`drop_last`. + num_workers (int, optional): how many subprocesses to use for data + loading. ``0`` means that the data will be loaded in the main process. + (default: ``0``) + collate_fn (Callable, optional): merges a list of samples to form a + mini-batch of Tensor(s). Used when using batched loading from a + map-style dataset. + pin_memory (bool, optional): If ``True``, the data loader will copy Tensors + into device/CUDA pinned memory before returning them. If your data elements + are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, + see the example below. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: ``False``) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. (default: ``0``) + worker_init_fn (Callable, optional): If not ``None``, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: ``None``) + multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If + ``None``, the default + `multiprocessing context `_ # noqa: D401 + of your operating system will + be used. (default: ``None``) + generator (torch.Generator, optional): If not ``None``, this RNG will be used + by RandomSampler to generate random indexes and multiprocessing to generate + ``base_seed`` for workers. (default: ``None``) + prefetch_factor (int, optional, keyword-only arg): Number of batches loaded + in advance by each worker. ``2`` means there will be a total of + 2 * num_workers batches prefetched across all workers. (default value depends + on the set value for num_workers. If value of num_workers=0 default is ``None``. + Otherwise, if value of ``num_workers > 0`` default is ``2``). + persistent_workers (bool, optional): If ``True``, the data loader will not shut down + the worker processes after a dataset has been consumed once. This allows to + maintain the workers `Dataset` instances alive. (default: ``False``) + pin_memory_device (str, optional): the device to :attr:`pin_memory` on if ``pin_memory`` is + ``True``. If not given, the current :ref:`accelerator` will be the + default. This argument is discouraged and subject to deprecated. + in_order (bool, optional): If ``False``, the data loader will not enforce that batches + are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``) + + + .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` + cannot be an unpicklable object, e.g., a lambda function. See + :ref:`multiprocessing-best-practices` on more details related + to multiprocessing in PyTorch. + + .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used. + When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`, + it instead returns an estimate based on ``len(dataset) / batch_size``, with proper + rounding depending on :attr:`drop_last`, regardless of multi-process loading + configurations. This represents the best guess PyTorch can make because PyTorch + trusts user :attr:`dataset` code in correctly handling multi-process + loading to avoid duplicate data. + + However, if sharding results in multiple workers having incomplete last batches, + this estimate can still be inaccurate, because (1) an otherwise complete batch can + be broken into multiple ones and (2) more than one batch worth of samples can be + dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such + cases in general. + + See `Dataset Types`_ for more details on these two types of datasets and how + :class:`~torch.utils.data.IterableDataset` interacts with + `Multi-process data loading`_. + + .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and + :ref:`data-loading-randomness` notes for random seed related questions. + + .. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data + distribution being fed to the trainer in cases with imbalanced data. + """ + + dataset: Dataset[_T_co] + batch_size: Optional[int] + num_workers: int + pin_memory: bool + drop_last: bool + timeout: float + sampler: Union[Sampler, Iterable] + pin_memory_device: str + prefetch_factor: Optional[int] + _iterator: Optional[_BaseDataLoaderIter] + __initialized = False + + def __init__( + self, + dataset: Dataset[_T_co], + batch_size: Optional[int] = 1, + shuffle: Optional[bool] = None, + sampler: Union[Sampler, Iterable, None] = None, + batch_sampler: Union[Sampler[list], Iterable[list], None] = None, + num_workers: int = 0, + collate_fn: Optional[_collate_fn_t] = None, + pin_memory: bool = False, + drop_last: bool = False, + timeout: float = 0, + worker_init_fn: Optional[_worker_init_fn_t] = None, + multiprocessing_context=None, + generator=None, + *, + prefetch_factor: Optional[int] = None, + persistent_workers: bool = False, + pin_memory_device: str = "", + in_order: bool = True, + ) -> None: + torch._C._log_api_usage_once("python.data_loader") + + if num_workers < 0: + raise ValueError( + "num_workers option should be non-negative; " + "use num_workers=0 to disable multiprocessing." + ) + + if timeout < 0: + raise ValueError("timeout option should be non-negative") + + if num_workers == 0 and prefetch_factor is not None: + raise ValueError( + "prefetch_factor option could only be specified in multiprocessing." + "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None." + ) + elif num_workers > 0 and prefetch_factor is None: + prefetch_factor = 2 + elif prefetch_factor is not None and prefetch_factor < 0: + raise ValueError("prefetch_factor option should be non-negative") + + if persistent_workers and num_workers == 0: + raise ValueError("persistent_workers option needs num_workers > 0") + + self.dataset = dataset + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.pin_memory = pin_memory + self.pin_memory_device = pin_memory_device + self.timeout = timeout + self.worker_init_fn = worker_init_fn + self.multiprocessing_context = multiprocessing_context + self.in_order = in_order + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler + if isinstance(self.dataset, IterDataPipe): + self.dataset = _IterDataPipeSerializationWrapper(self.dataset) + elif isinstance(self.dataset, MapDataPipe): + self.dataset = _MapDataPipeSerializationWrapper(self.dataset) + + # Arg-check dataset related before checking samplers because we want to + # tell users that iterable-style datasets are incompatible with custom + # samplers first, so that they don't learn that this combo doesn't work + # after spending time fixing the custom sampler errors. + if isinstance(dataset, IterableDataset): + self._dataset_kind = _DatasetKind.Iterable + # NOTE [ Custom Samplers and IterableDataset ] + # + # `IterableDataset` does not support custom `batch_sampler` or + # `sampler` since the key is irrelevant (unless we support + # generator-style dataset one day...). + # + # For `sampler`, we always create a dummy sampler. This is an + # infinite sampler even when the dataset may have an implemented + # finite `__len__` because in multi-process data loading, naive + # settings will return duplicated data (which may be desired), and + # thus using a sampler with length matching that of dataset will + # cause data lost (you may have duplicates of the first couple + # batches, but never see anything afterwards). Therefore, + # `Iterabledataset` always uses an infinite sampler, an instance of + # `_InfiniteConstantSampler` defined above. + # + # A custom `batch_sampler` essentially only controls the batch size. + # However, it is unclear how useful it would be since an iterable-style + # dataset can handle that within itself. Moreover, it is pointless + # in multi-process data loading as the assignment order of batches + # to workers is an implementation detail so users can not control + # how to batchify each worker's iterable. Thus, we disable this + # option. If this turns out to be useful in future, we can re-enable + # this, and support custom samplers that specify the assignments to + # specific workers. + if isinstance(dataset, IterDataPipe): + if shuffle is not None: + dataset = torch.utils.data.graph_settings.apply_shuffle_settings( + dataset, shuffle=shuffle + ) + # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default. + elif shuffle not in {False, None}: + raise ValueError( + f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}" + ) + + if sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}" + ) + elif batch_sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + f"batch_sampler option, but got batch_sampler={batch_sampler}" + ) + else: + shuffle = bool(shuffle) + self._dataset_kind = _DatasetKind.Map + + if sampler is not None and shuffle: + raise ValueError("sampler option is mutually exclusive with shuffle") + + if batch_sampler is not None: + # auto_collation with custom batch_sampler + if batch_size != 1 or shuffle or sampler is not None or drop_last: + raise ValueError( + "batch_sampler option is mutually exclusive " + "with batch_size, shuffle, sampler, and " + "drop_last" + ) + batch_size = None + drop_last = False + elif batch_size is None: + # no auto_collation + if drop_last: + raise ValueError( + "batch_size=None option disables auto-batching " + "and is mutually exclusive with drop_last" + ) + + if sampler is None: # give default samplers + if self._dataset_kind == _DatasetKind.Iterable: + # See NOTE [ Custom Samplers and IterableDataset ] + sampler = _InfiniteConstantSampler() + else: # map-style + if shuffle: + sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type] + else: + sampler = SequentialSampler(dataset) # type: ignore[arg-type] + + if batch_size is not None and batch_sampler is None: + # auto_collation without custom batch_sampler + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.batch_size = batch_size + self.drop_last = drop_last + self.sampler = sampler + self.batch_sampler = batch_sampler + self.generator = generator + + if collate_fn is None: + if self._auto_collation: + collate_fn = _utils.collate.default_collate + else: + collate_fn = _utils.collate.default_convert + + self.collate_fn = collate_fn + self.persistent_workers = persistent_workers + + self.__initialized = True + self._IterableDataset_len_called = ( + None # See NOTE [ IterableDataset and __len__ ] + ) + + self._iterator = None + + self.check_worker_number_rationality() + + torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined] + + def _get_iterator(self) -> _BaseDataLoaderIter: + if self.num_workers == 0: + return _SingleProcessDataLoaderIter(self) + else: + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIter(self) + + @property + def multiprocessing_context(self): + return self.__multiprocessing_context + + @multiprocessing_context.setter + def multiprocessing_context(self, multiprocessing_context): + if multiprocessing_context is not None: + if self.num_workers > 0: + if isinstance(multiprocessing_context, str): + valid_start_methods = torch.multiprocessing.get_all_start_methods() + if multiprocessing_context not in valid_start_methods: + raise ValueError( + "multiprocessing_context option " + f"should specify a valid start method in {valid_start_methods!r}, but got " + f"multiprocessing_context={multiprocessing_context!r}" + ) + multiprocessing_context = torch.multiprocessing.get_context( + multiprocessing_context + ) + + if not isinstance( + multiprocessing_context, python_multiprocessing.context.BaseContext + ): + raise TypeError( + "multiprocessing_context option should be a valid context " + "object or a string specifying the start method, but got " + f"multiprocessing_context={multiprocessing_context}" + ) + else: + raise ValueError( + "multiprocessing_context can only be used with " + "multi-process loading (num_workers > 0), but got " + f"num_workers={self.num_workers}" + ) + + self.__multiprocessing_context = multiprocessing_context + + def __setattr__(self, attr, val): + if self.__initialized and attr in ( + "batch_size", + "batch_sampler", + "sampler", + "drop_last", + "dataset", + "persistent_workers", + ): + raise ValueError( + f"{attr} attribute should not be set after {self.__class__.__name__} is initialized" + ) + + super().__setattr__(attr, val) + + def __iter__(self) -> _BaseDataLoaderIter: + # When using a single worker the returned iterator should be + # created everytime to avoid resetting its state + # However, in the case of a multiple workers iterator + # the iterator is only created once in the lifetime of the + # DataLoader object so that workers can be reused + if self.persistent_workers and self.num_workers > 0: + if self._iterator is None: + self._iterator = self._get_iterator() + else: + self._iterator._reset(self) + return self._iterator + else: + return self._get_iterator() + + @property + def _auto_collation(self): + return self.batch_sampler is not None + + @property + def _index_sampler(self): + # The actual sampler used for generating indices for `_DatasetFetcher` + # (see _utils/fetch.py) to read data at each time. This would be + # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise. + # We can't change `.sampler` and `.batch_sampler` attributes for BC + # reasons. + if self._auto_collation: + return self.batch_sampler + else: + return self.sampler + + def __len__(self) -> int: + if self._dataset_kind == _DatasetKind.Iterable: + # NOTE [ IterableDataset and __len__ ] + # + # For `IterableDataset`, `__len__` could be inaccurate when one naively + # does multi-processing data loading, since the samples will be duplicated. + # However, no real use case should be actually using that behavior, so + # it should count as a user error. We should generally trust user + # code to do the proper thing (e.g., configure each replica differently + # in `__iter__`), and give us the correct `__len__` if they choose to + # implement it (this will still throw if the dataset does not implement + # a `__len__`). + # + # To provide a further warning, we track if `__len__` was called on the + # `DataLoader`, save the returned value in `self._len_called`, and warn + # if the iterator ends up yielding more than this number of samples. + + # Cannot statically verify that dataset is Sized + length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type] + if ( + self.batch_size is not None + ): # IterableDataset doesn't allow custom sampler or batch_sampler + from math import ceil + + if self.drop_last: + length = length // self.batch_size + else: + length = ceil(length / self.batch_size) + return length + else: + return len(self._index_sampler) + + def check_worker_number_rationality(self): + # This function check whether the dataloader's worker number is rational based on + # current system's resource. Current rule is that if the number of workers this + # Dataloader will create is bigger than the number of logical cpus that is allowed to + # use, than we will pop up a warning to let user pay attention. + # + # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2 + # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current + # DataLoader process can use half of them which is 32, then the rational max number of + # worker that initiated from this process is 32. + # Now, let's say the created DataLoader has num_works = 40, which is bigger than 32. + # So the warning message is triggered to notify the user to lower the worker number if + # necessary. + # + # + # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is + # available (available in most of Linux system, but not OSX and Windows). + # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but + # it doesn't repect cpuset. + # We don't take threading into account since each worker process is single threaded + # at this time. + # + # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc) + # other than `torch.set_num_threads` to 1 in the worker process, if the passing + # in functions use 3rd party modules that rely on those threading flags to determine + # how many thread to create (eg. numpy, etc), then it is caller's responsibility to + # set those flags correctly. + def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked): + suggested_max_worker_msg = ( + ( + ( + "Our suggested max number of worker in current system is {}{}, which is smaller " + "than what this DataLoader is going to create." + ).format( + num_worker_suggest, + ( + "" + if cpuset_checked + else " (`cpuset` is not taken into account)" + ), + ) + ) + if num_worker_suggest is not None + else ( + "DataLoader is not able to compute a suggested max number of worker in current system." + ) + ) + + warn_msg = ( + f"This DataLoader will create {num_worker_created} worker processes in total. {suggested_max_worker_msg} " + "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, " + "lower the worker number to avoid potential slowness/freeze if necessary." + ) + return warn_msg + + if not self.num_workers or self.num_workers == 0: + return + + # try to compute a suggested max number of worker based on system's resource + max_num_worker_suggest = None + cpuset_checked = False + if hasattr(os, "sched_getaffinity"): + try: + max_num_worker_suggest = len(os.sched_getaffinity(0)) + cpuset_checked = True + except Exception: + pass + if max_num_worker_suggest is None: + # os.cpu_count() could return Optional[int] + # get cpu count first and check None in order to satisfy mypy check + cpu_count = os.cpu_count() + if cpu_count is not None: + max_num_worker_suggest = cpu_count + + if max_num_worker_suggest is None: + warnings.warn( + _create_warning_msg( + max_num_worker_suggest, self.num_workers, cpuset_checked + ) + ) + return + + if self.num_workers > max_num_worker_suggest: + warnings.warn( + _create_warning_msg( + max_num_worker_suggest, self.num_workers, cpuset_checked + ) + ) + + +class _BaseDataLoaderIter: + def __init__(self, loader: DataLoader) -> None: + self._dataset = loader.dataset + self._shared_seed = None + self._pg = None + if isinstance(self._dataset, IterDataPipe): + if dist.is_available() and dist.is_initialized(): + self._pg = dist.new_group(backend="gloo") + self._shared_seed = _share_dist_seed(loader.generator, self._pg) + shared_rng = torch.Generator() + shared_rng.manual_seed(self._shared_seed) + self._dataset = torch.utils.data.graph_settings.apply_random_seed( + self._dataset, shared_rng + ) + self._dataset_kind = loader._dataset_kind + self._IterableDataset_len_called = loader._IterableDataset_len_called + self._auto_collation = loader._auto_collation + self._drop_last = loader.drop_last + self._index_sampler = loader._index_sampler + self._num_workers = loader.num_workers + ws, rank = _get_distributed_settings() + self._world_size = ws + self._rank = rank + # If pin_memory_device not set, default behaviour is current accelerator. + # If pin_memory_device is set but pin_memory is not set, the default + # behaviour false. + if len(loader.pin_memory_device) == 0: + if loader.pin_memory and not torch.accelerator.is_available(): + warn_msg = ( + "'pin_memory' argument is set as true but no accelerator is found, " + "then device pinned memory won't be used." + ) + warnings.warn(warn_msg) + + self._pin_memory = loader.pin_memory and torch.accelerator.is_available() + self._pin_memory_device = None + # Currently, pin_memory would raise error on the MPS backend (see + # https://github.com/pytorch/pytorch/issues/86060), so forcibly + # disable pin_memory on MPS. Remove this restriction once pinned + # memory allocation for MPS is fixed. + if ( + self._pin_memory + and (acc := torch.accelerator.current_accelerator()) is not None + and acc.type == "mps" + ): + self._pin_memory = False + warn_msg = ( + "'pin_memory' argument is set as true but not supported on MPS now, " + "then device pinned memory won't be used." + ) + warnings.warn(warn_msg) + else: + if not loader.pin_memory: + warn_msg = ( + "'pin_memory_device' is set but 'pin_memory' argument is not set, " + "then device pinned memory won't be used." + "please set 'pin_memory' to true, if you need to use the device pin memory" + ) + warnings.warn(warn_msg) + + self._pin_memory = loader.pin_memory + self._pin_memory_device = loader.pin_memory_device + self._timeout = loader.timeout + self._collate_fn = loader.collate_fn + self._sampler_iter = iter(self._index_sampler) + self._base_seed = ( + torch.empty((), dtype=torch.int64) + .random_(generator=loader.generator) + .item() + ) + self._persistent_workers = loader.persistent_workers + self._num_yielded = 0 + self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__" + + def __iter__(self) -> Self: + return self + + def _reset(self, loader, first_iter=False): + self._sampler_iter = iter(self._index_sampler) + self._num_yielded = 0 + self._IterableDataset_len_called = loader._IterableDataset_len_called + if isinstance(self._dataset, IterDataPipe): + self._shared_seed = _share_dist_seed(loader.generator, self._pg) + shared_rng = torch.Generator() + shared_rng.manual_seed(self._shared_seed) + self._dataset = torch.utils.data.graph_settings.apply_random_seed( + self._dataset, shared_rng + ) + + def _next_index(self): + return next(self._sampler_iter) # may raise StopIteration + + def _next_data(self): + raise NotImplementedError + + def __next__(self) -> Any: + with torch.autograd.profiler.record_function(self._profile_name): + if self._sampler_iter is None: + # TODO(https://github.com/pytorch/pytorch/issues/76750) + self._reset() # type: ignore[call-arg] + data = self._next_data() + self._num_yielded += 1 + if ( + self._dataset_kind == _DatasetKind.Iterable + and self._IterableDataset_len_called is not None + and self._num_yielded > self._IterableDataset_len_called + ): + warn_msg = ( + f"Length of IterableDataset {self._dataset} was reported to be {self._IterableDataset_len_called}" + f"(when accessing len(dataloader)), but {self._num_yielded} samples have been fetched. " + ) + if self._num_workers > 0: + warn_msg += ( + "For multiprocessing data-loading, this could be caused by not properly configuring the " + "IterableDataset replica at each worker. Please see " + "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples." + ) + warnings.warn(warn_msg) + return data + + def __len__(self) -> int: + return len(self._index_sampler) + + def __getstate__(self): + # TODO: add limited pickling support for sharing an iterator + # across multiple threads for HOGWILD. + # Probably the best way to do this is by moving the sample pushing + # to a separate thread and then just sharing the data queue + # but signalling the end is tricky without a non-blocking API + raise NotImplementedError("{} cannot be pickled", self.__class__.__name__) + + +class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): + def __init__(self, loader): + super().__init__(loader) + assert self._timeout == 0 + assert self._num_workers == 0 + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # Taking care of distributed sharding + if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): + # For BC, use default SHARDING_PRIORITIES + torch.utils.data.graph_settings.apply_sharding( + self._dataset, self._world_size, self._rank + ) + + self._dataset_fetcher = _DatasetKind.create_fetcher( + self._dataset_kind, + self._dataset, + self._auto_collation, + self._collate_fn, + self._drop_last, + ) + + def _next_data(self): + index = self._next_index() # may raise StopIteration + data = self._dataset_fetcher.fetch(index) # may raise StopIteration + if self._pin_memory: + data = _utils.pin_memory.pin_memory(data, self._pin_memory_device) + return data + + +class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): + r"""Iterates once over the DataLoader's dataset, as specified by the sampler.""" + + # NOTE [ Data Loader Multiprocessing Shutdown Logic ] + # + # Preliminary: + # + # Our data model looks like this (queues are indicated with curly brackets): + # + # main process || + # | || + # {index_queue} || + # | || + # worker processes || DATA + # | || + # {worker_result_queue} || FLOW + # | || + # pin_memory_thread of main process || DIRECTION + # | || + # {data_queue} || + # | || + # data output \/ + # + # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if + # `pin_memory=False`. + # + # + # Terminating multiprocessing logic requires very careful design. In + # particular, we need to make sure that + # + # 1. The iterator gracefully exits the workers when its last reference is + # gone or it is depleted. + # + # In this case, the workers should be gracefully exited because the + # main process may still need to continue to run, and we want cleaning + # up code in the workers to be executed (e.g., releasing GPU memory). + # Naturally, we implement the shutdown logic in `__del__` of + # DataLoaderIterator. + # + # We delay the discussion on the logic in this case until later. + # + # 2. The iterator exits the workers when the loader process and/or worker + # processes exits normally or with error. + # + # We set all workers and `pin_memory_thread` to have `daemon=True`. + # + # You may ask, why can't we make the workers non-daemonic, and + # gracefully exit using the same logic as we have in `__del__` when the + # iterator gets deleted (see 1 above)? + # + # First of all, `__del__` is **not** guaranteed to be called when + # interpreter exits. Even if it is called, by the time it executes, + # many Python core library resources may already be freed, and even + # simple things like acquiring an internal lock of a queue may hang. + # Therefore, in this case, we actually need to prevent `__del__` from + # being executed, and rely on the automatic termination of daemonic + # children. + # + # Thus, we register an `atexit` hook that sets a global flag + # `_utils.python_exit_status`. Since `atexit` hooks are executed in the + # reverse order of registration, we are guaranteed that this flag is + # set before library resources we use are freed (which, at least in + # CPython, is done via an `atexit` handler defined in + # `multiprocessing/util.py` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362 + # registered when an object requiring this mechanism is first + # created, e.g., `mp.Queue` + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103 + # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29 + # ) + # + # So in `__del__`, we check if `_utils.python_exit_status` is set or + # `None` (freed), and perform no-op if so. + # + # However, simply letting library clean-up codes run can also be bad, + # because such codes (i.e., `multiprocessing.util._exit_function()`) + # include join putting threads for `mp.Queue`, which can be blocking. + # Hence, the main process putting threads are called with + # `cancel_join_thread` at creation. See later section + # [ 3b. A process won't hang when putting into a queue; ] + # for more details. + # + # Here are two example cases where library clean-up codes can run + # before `__del__` is called: + # + # 1. If we hold onto a reference to the iterator, it more often + # than not tries to do `multiprocessing` library cleaning before + # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666) + # and thus prevents our cleaning-up code to run first. + # + # 2. A similar issue araises when a `DataLoader` is used in a subprocess. + # When a process ends, it shuts the all its daemonic children + # down with a SIGTERM (instead of joining them without a timeout). + # Simiarly for threads, but by a different mechanism. This fact, + # together with a few implementation details of multiprocessing, forces + # us to make workers daemonic. All of our problems arise when a + # DataLoader is used in a subprocess, and are caused by multiprocessing + # code which looks more or less like this: + # + # try: + # your_function_using_a_dataloader() + # finally: + # multiprocessing.util._exit_function() + # + # The joining/termination mentioned above happens inside + # `_exit_function()`. Now, if `your_function_using_a_dataloader()` + # throws, the stack trace stored in the exception will prevent the + # frame which uses `DataLoaderIter` to be freed. If the frame has any + # reference to the `DataLoaderIter` (e.g., in a method of the iter), + # its `__del__`, which starts the shutdown procedure, will not be + # called. That, in turn, means that workers aren't notified. Attempting + # to join in `_exit_function` will then result in a hang. + # + # For context, `_exit_function` is also registered as an `atexit` call. + # So it is unclear to me (@ssnl) why this is needed in a finally block. + # The code dates back to 2008 and there is no comment on the original + # PEP 371 or patch https://bugs.python.org/issue3050 (containing both + # the finally block and the `atexit` registration) that explains this. + # + # + # Finally, another choice is to just shutdown workers with logic in 1 + # above whenever we see an error in `next`. This isn't ideal because + # a. It prevents users from using try-catch to resume data loading. + # b. It doesn't prevent hanging if users have references to the + # iterator. + # + # 3. All processes exit if any of them die unexpectedly by fatal signals. + # + # As shown above, the workers are set as daemonic children of the main + # process. However, automatic cleaning-up of such child processes only + # happens if the parent process exits gracefully (e.g., not via fatal + # signals like SIGKILL). So we must ensure that each process will exit + # even the process that should send/receive data to/from it were + # killed, i.e., + # + # a. A process won't hang when getting from a queue. + # + # Even with carefully designed data dependencies (i.e., a `put()` + # always corresponding to a `get()`), hanging on `get()` can still + # happen when data in queue is corrupted (e.g., due to + # `cancel_join_thread` or unexpected exit). + # + # For child exit, we set a timeout whenever we try to get data + # from `data_queue`, and check the workers' status on each timeout + # and error. + # See `_DataLoaderiter._get_batch()` and + # `_DataLoaderiter._try_get_data()` for details. + # + # Additionally, for child exit on non-Windows platforms, we also + # register a SIGCHLD handler (which is supported on Windows) on + # the main process, which checks if any of the workers fail in the + # (Python) handler. This is more efficient and faster in detecting + # worker failures, compared to only using the above mechanism. + # See `DataLoader.cpp` and `_utils/signal_handling.py` for details. + # + # For `.get()` calls where the sender(s) is not the workers, we + # guard them with timeouts, and check the status of the sender + # when timeout happens: + # + in the workers, the `_utils.worker.ManagerWatchdog` class + # checks the status of the main process. + # + if `pin_memory=True`, when getting from `pin_memory_thread`, + # check `pin_memory_thread` status periodically until `.get()` + # returns or see that `pin_memory_thread` died. + # + # b. A process won't hang when putting into a queue; + # + # We use `mp.Queue` which has a separate background thread to put + # objects from an unbounded buffer array. The background thread is + # daemonic and usually automatically joined when the process + # *exits*. + # + # In case that the receiver has ended abruptly while + # reading from the pipe, the join will hang forever. The usual + # solution for this in Python is calling `q.cancel_join_thread`, + # which prevents automatically joining it when finalizing + # (exiting). + # + # Nonetheless, `cancel_join_thread` must only be called when the + # queue is **not** going to be read from or write into by another + # process, because it may hold onto a lock or leave corrupted data + # in the queue, leading other readers/writers to hang. + # + # Hence, + # + For worker processes, we only do so (for their output + # queues, i.e., `worker_result_queue`) before exiting. + # + For `pin_memory_thread`, its output queue `data_queue` is a + # `queue.Queue` that does blocking `put` if the queue is full. + # So there is no above problem, but as a result, in + # `_pin_memory_loop`, we do need to wrap the `put` in a loop + # that breaks not only upon success, but also when the main + # process stops reading, i.e., is shutting down. + # + For loader process, we `cancel_join_thread()` for all + # `_index_queues` because the whole purpose of workers and + # `pin_memory_thread` is to serve the loader process. If + # loader process is already exiting, we don't really care if + # the queues are corrupted. + # + # + # Now let's get back to 1: + # how we gracefully exit the workers when the last reference to the + # iterator is gone. + # + # To achieve this, we implement the following logic along with the design + # choices mentioned above: + # + # `workers_done_event`: + # A `multiprocessing.Event` shared among the main process and all worker + # processes. This is used to signal the workers that the iterator is + # shutting down. After it is set, they will not send processed data to + # queues anymore, and only wait for the final `None` before exiting. + # `done_event` isn't strictly needed. I.e., we can just check for `None` + # from the input queue, but it allows us to skip wasting resources + # processing data if we are already shutting down. + # + # `pin_memory_thread_done_event`: + # A `threading.Event` for a similar purpose to that of + # `workers_done_event`, but is for the `pin_memory_thread`. The reason + # that separate events are needed is that `pin_memory_thread` reads from + # the output queue of the workers. But the workers, upon seeing that + # `workers_done_event` is set, only wants to see the final `None`, and is + # not required to flush all data in the output queue (e.g., it may call + # `cancel_join_thread` on that queue if its `IterableDataset` iterator + # happens to exhaust coincidentally, which is out of the control of the + # main process). Thus, since we will exit `pin_memory_thread` before the + # workers (see below), two separete events are used. + # + # NOTE: In short, the protocol is that the main process will set these + # `done_event`s and then the corresponding processes/threads a `None`, + # and that they may exit at any time after receiving the `None`. + # + # NOTE: Using `None` as the final signal is valid, since normal data will + # always be a 2-tuple with the 1st element being the index of the data + # transferred (different from dataset index/key), and the 2nd being + # either the dataset key or the data sample (depending on which part + # of the data model the queue is at). + # + # [ worker processes ] + # While loader process is alive: + # Get from `index_queue`. + # If get anything else, + # Check `workers_done_event`. + # If set, continue to next iteration + # i.e., keep getting until see the `None`, then exit. + # Otherwise, process data: + # If is fetching from an `IterableDataset` and the iterator + # is exhausted, send an `_IterableDatasetStopIteration` + # object to signal iteration end. The main process, upon + # receiving such an object, will send `None` to this + # worker and not use the corresponding `index_queue` + # anymore. + # If timed out, + # No matter `workers_done_event` is set (still need to see `None`) + # or not, must continue to next iteration. + # (outside loop) + # If `workers_done_event` is set, (this can be False with `IterableDataset`) + # `data_queue.cancel_join_thread()`. (Everything is ending here: + # main process won't read from it; + # other workers will also call + # `cancel_join_thread`.) + # + # [ pin_memory_thread ] + # # No need to check main thread. If this thread is alive, the main loader + # # thread must be alive, because this thread is set as daemonic. + # While `pin_memory_thread_done_event` is not set: + # Get from `worker_result_queue`. + # If timed out, continue to get in the next iteration. + # Otherwise, process data. + # While `pin_memory_thread_done_event` is not set: + # Put processed data to `data_queue` (a `queue.Queue` with blocking put) + # If timed out, continue to put in the next iteration. + # Otherwise, break, i.e., continuing to the out loop. + # + # NOTE: we don't check the status of the main thread because + # 1. if the process is killed by fatal signal, `pin_memory_thread` + # ends. + # 2. in other cases, either the cleaning-up in __del__ or the + # automatic exit of daemonic thread will take care of it. + # This won't busy-wait either because `.get(timeout)` does not + # busy-wait. + # + # [ main process ] + # In the DataLoader Iter's `__del__` + # b. Exit `pin_memory_thread` + # i. Set `pin_memory_thread_done_event`. + # ii Put `None` in `worker_result_queue`. + # iii. Join the `pin_memory_thread`. + # iv. `worker_result_queue.cancel_join_thread()`. + # + # c. Exit the workers. + # i. Set `workers_done_event`. + # ii. Put `None` in each worker's `index_queue`. + # iii. Join the workers. + # iv. Call `.cancel_join_thread()` on each worker's `index_queue`. + # + # NOTE: (c) is better placed after (b) because it may leave corrupted + # data in `worker_result_queue`, which `pin_memory_thread` + # reads from, in which case the `pin_memory_thread` can only + # happen at timing out, which is slow. Nonetheless, same thing + # happens if a worker is killed by signal at unfortunate times, + # but in other cases, we are better off having a non-corrupted + # `worker_result_queue` for `pin_memory_thread`. + # + # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b) + # can be omitted + # + # NB: `done_event`s isn't strictly needed. E.g., we can just check for + # `None` from `index_queue`, but it allows us to skip wasting resources + # processing indices already in `index_queue` if we are already shutting + # down. + + def __init__(self, loader): + super().__init__(loader) + + self._prefetch_factor = loader.prefetch_factor + self._in_order = loader.in_order + + assert self._num_workers > 0 + assert self._prefetch_factor > 0 + + if loader.multiprocessing_context is None: + multiprocessing_context = torch.multiprocessing + else: + multiprocessing_context = loader.multiprocessing_context + + self._worker_init_fn = loader.worker_init_fn + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # Additional worker init function will take care of sharding in MP and Distributed + if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): + self._worker_init_fn = functools.partial( + _sharding_worker_init_fn, + self._worker_init_fn, + self._world_size, + self._rank, + ) + + # No certainty which module multiprocessing_context is + self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + self._worker_pids_set = False + self._shutdown = False + self._workers_done_event = multiprocessing_context.Event() + + self._index_queues = [] + self._workers = [] + for i in range(self._num_workers): + # No certainty which module multiprocessing_context is + index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + # Need to `cancel_join_thread` here! + # See sections (2) and (3b) above. + index_queue.cancel_join_thread() + w = multiprocessing_context.Process( + target=_utils.worker._worker_loop, + args=( + self._dataset_kind, + self._dataset, + index_queue, + self._worker_result_queue, + self._workers_done_event, + self._auto_collation, + self._collate_fn, + self._drop_last, + self._base_seed, + self._worker_init_fn, + i, + self._num_workers, + self._persistent_workers, + self._shared_seed, + ), + ) + w.daemon = True + # NB: Process.start() actually take some time as it needs to + # start a process and pass the arguments over via a pipe. + # Therefore, we only add a worker to self._workers list after + # it started, so that we do not call .join() if program dies + # before it starts, and __del__ tries to join but will get: + # AssertionError: can only join a started process. + w.start() + self._index_queues.append(index_queue) + self._workers.append(w) + + if self._pin_memory: + self._pin_memory_thread_done_event = threading.Event() + + # Queue is not type-annotated + self._data_queue = queue.Queue() # type: ignore[var-annotated] + current_device = -1 + if self._pin_memory_device == "cuda": + current_device = torch.cuda.current_device() + elif self._pin_memory_device == "xpu": + current_device = torch.xpu.current_device() + elif self._pin_memory_device == torch._C._get_privateuse1_backend_name(): + custom_device_mod = getattr( + torch, torch._C._get_privateuse1_backend_name() + ) + current_device = custom_device_mod.current_device() + elif self._pin_memory_device is None: + current_device = torch.accelerator.current_device_index() + pin_memory_thread = threading.Thread( + target=_utils.pin_memory._pin_memory_loop, + args=( + self._worker_result_queue, + self._data_queue, + current_device, + self._pin_memory_thread_done_event, + self._pin_memory_device, + ), + ) + pin_memory_thread.daemon = True + pin_memory_thread.start() + # Similar to workers (see comment above), we only register + # pin_memory_thread once it is started. + self._pin_memory_thread = pin_memory_thread + else: + self._data_queue = self._worker_result_queue # type: ignore[assignment] + + # In some rare cases, persistent workers (daemonic processes) + # would be terminated before `__del__` of iterator is invoked + # when main process exits + # It would cause failure when pin_memory_thread tries to read + # corrupted data from worker_result_queue + # atexit is used to shutdown thread and child processes in the + # right sequence before main process exits + if self._persistent_workers and self._pin_memory: + import atexit + + for w in self._workers: + atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w) + + # .pid can be None only before process is spawned (not the case, so ignore) + _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] + _utils.signal_handling._set_SIGCHLD_handler() + self._worker_pids_set = True + self._reset(loader, first_iter=True) + + def _reset(self, loader, first_iter=False): + super()._reset(loader, first_iter) + self._send_idx = 0 # idx of the next task to be sent to workers + self._rcvd_idx = 0 # idx of the next task to be returned in __next__ + # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx). + # map: task idx => - (worker_id,) if data isn't fetched (outstanding) + # \ (worker_id, data) if data is already fetched (out-of-order) + self._task_info = {} + self._tasks_outstanding = ( + 0 # always equal to count(v for v in task_info.values() if len(v) == 1) + ) + # A list of booleans representing whether each worker still has work to + # do, i.e., not having exhausted its iterable dataset object. It always + # contains all `True`s if not using an iterable-style dataset + # (i.e., if kind != Iterable). + # Not that this indicates that a worker still has work to do *for this epoch*. + # It does not mean that a worker is dead. In case of `_persistent_workers`, + # the worker will be reset to available in the next epoch. + self._workers_status = [True for i in range(self._num_workers)] + # A list of integers representing how many tasks are outstanding for each worker + # Incremented when a task is dispatched to the worker + # Decremented when that data has been given to the main thread + # Each worker should have at most self._prefetch_factor tasks outstanding + self._workers_num_tasks = [0 for i in range(self._num_workers)] + # Reset the worker queue cycle so it resumes next epoch at worker 0 + self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers)) + # We resume the prefetching in case it was enabled + if not first_iter: + for idx in range(self._num_workers): + self._index_queues[idx].put( + _utils.worker._ResumeIteration(self._shared_seed) + ) + resume_iteration_cnt = self._num_workers + while resume_iteration_cnt > 0: + return_idx, return_data = self._get_data() + if isinstance(return_idx, _utils.worker._ResumeIteration): + assert return_data is None + resume_iteration_cnt -= 1 + # prime the prefetch loop + for _ in range(self._prefetch_factor * self._num_workers): + self._try_put_index() + + def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): + # Tries to fetch data from `self._data_queue` once for a given timeout. + # This can also be used as inner loop of fetching without timeout, with + # the sender status as the loop condition. + # + # This raises a `RuntimeError` if any worker died expectedly. This error + # can come from either the SIGCHLD handler in `_utils/signal_handling.py` + # (only for non-Windows platforms), or the manual check below on errors + # and timeouts. + # + # Returns a 2-tuple: + # (bool: whether successfully get data, any: data if successful else None) + try: + data = self._data_queue.get(timeout=timeout) + return (True, data) + except Exception as e: + # At timeout and error, we manually check whether any worker has + # failed. Note that this is the only mechanism for Windows to detect + # worker failures. + failed_workers = [] + for worker_id, w in enumerate(self._workers): + if self._workers_status[worker_id] and not w.is_alive(): + failed_workers.append(w) + self._mark_worker_as_unavailable(worker_id) + if len(failed_workers) > 0: + pids_str = ", ".join(str(w.pid) for w in failed_workers) + raise RuntimeError( + f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly" + ) from e + if isinstance(e, queue.Empty): + return (False, None) + + import errno + import tempfile + + try: + # Raise an exception if we are this close to the FDs limit. + # Apparently, trying to open only one file is not a sufficient + # test. + # See NOTE [ DataLoader on Linux and open files limit ] + fds_limit_margin = 10 + [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)] + except OSError as e: + if e.errno == errno.EMFILE: + raise RuntimeError( + "Too many open files. Communication with the" + " workers is no longer possible. Please increase the" + " limit using `ulimit -n` in the shell or change the" + " sharing strategy by calling" + " `torch.multiprocessing.set_sharing_strategy('file_system')`" + " at the beginning of your code" + ) from None + raise + + # NOTE [ DataLoader on Linux and open files limit ] + # + # On Linux when DataLoader is used with multiprocessing we pass the data between + # the root process and the workers through SHM files. We remove those files from + # the filesystem as soon as they are created and keep them alive by + # passing around their file descriptors through AF_UNIX sockets. (See + # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in + # the wiki (https://github.com/pytorch/pytorch/wiki).) + # + # This sometimes leads us to exceeding the open files limit. When that happens, + # and the offending file descriptor is coming over a socket, the `socket` Python + # package silently strips the file descriptor from the message, setting only the + # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that + # it _indicates that some control data were discarded due to lack of space in + # the buffer for ancillary data_). This might reflect the C implementation of + # AF_UNIX sockets. + # + # This behaviour can be reproduced with the script and instructions at the + # bottom of this note. + # + # When that happens, the standard Python `multiprocessing` (and not + # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata` + # + # Sometimes, instead of the FD being stripped, you may get an `OSError: + # Too many open files`, both in the script below and in DataLoader. However, + # this is rare and seems to be nondeterministic. + # + # + # #!/usr/bin/env python3 + # import sys + # import socket + # import os + # import array + # import shutil + # import socket + # + # + # if len(sys.argv) != 4: + # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)") + # sys.exit(1) + # + # if __name__ == '__main__': + # dirname = sys.argv[1] + # sock_path = dirname + "/sock" + # iterations = int(sys.argv[2]) + # def dummy_path(i): + # return dirname + "/" + str(i) + ".dummy" + # + # + # if sys.argv[3] == 'send': + # while not os.path.exists(sock_path): + # pass + # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + # client.connect(sock_path) + # for i in range(iterations): + # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT) + # ancdata = array.array('i', [fd]) + # msg = bytes([i % 256]) + # print("Sending fd ", fd, " (iteration #", i, ")") + # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)]) + # + # + # else: + # assert sys.argv[3] == 'recv' + # + # if os.path.exists(dirname): + # raise Exception("Directory exists") + # + # os.mkdir(dirname) + # + # print("Opening socket...") + # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + # server.bind(sock_path) + # + # print("Listening...") + # for i in range(iterations): + # a = array.array('i') + # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize)) + # assert(len(ancdata) == 1) + # cmsg_level, cmsg_type, cmsg_data = ancdata[0] + # a.frombytes(cmsg_data) + # print("Received fd ", a[0], " (iteration #", i, ")") + # + # shutil.rmtree(dirname) + # + # Steps to reproduce: + # + # 1. Run two shells and set lower file descriptor limit in the receiving one: + # (shell1) ulimit -n 1020 + # (shell2) ulimit -n 1022 + # + # 2. Run the script above with the `recv` option in the first shell + # (shell1) ./test_socket.py sock_tmp 1017 recv + # + # 3. Run the script with the `send` option in the second shell: + # (shell2) ./test_socket.py sock_tmp 1017 send + + def _get_data(self): + # Fetches data from `self._data_queue`. + # + # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds, + # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` + # in a loop. This is the only mechanism to detect worker failures for + # Windows. For other platforms, a SIGCHLD handler is also used for + # worker failure detection. + # + # If `pin_memory=True`, we also need check if `pin_memory_thread` had + # died at timeouts. + if self._timeout > 0: + success, data = self._try_get_data(self._timeout) + if success: + return data + else: + raise RuntimeError( + f"DataLoader timed out after {self._timeout} seconds" + ) + elif self._pin_memory: + while self._pin_memory_thread.is_alive(): + success, data = self._try_get_data() + if success: + return data + else: + # while condition is false, i.e., pin_memory_thread died. + raise RuntimeError("Pin memory thread exited unexpectedly") + # In this case, `self._data_queue` is a `queue.Queue`,. But we don't + # need to call `.task_done()` because we don't use `.join()`. + else: + while True: + success, data = self._try_get_data() + if success: + return data + + def _next_data(self): + while True: + # If the worker responsible for `self._rcvd_idx` has already ended + # and was unable to fulfill this task (due to exhausting an `IterableDataset`), + # we try to advance `self._rcvd_idx` to find the next valid index. + # + # This part needs to run in the loop because both the `self._get_data()` + # call and `_IterableDatasetStopIteration` check below can mark + # extra worker(s) as dead. + while self._rcvd_idx < self._send_idx: + info = self._task_info.get(self._rcvd_idx, None) + if info: + worker_id = info[0] + if ( + len(info) == 2 or self._workers_status[worker_id] + ): # has data or is still active + break + del self._task_info[self._rcvd_idx] + self._rcvd_idx += 1 + else: + # no valid `self._rcvd_idx` is found (i.e., didn't break) + if not self._persistent_workers: + self._shutdown_workers() + raise StopIteration + + # Now `self._rcvd_idx` is the batch index we want to fetch + + # Check if the next sample has already been generated + if len(self._task_info[self._rcvd_idx]) == 2: + worker_id, data = self._task_info.pop(self._rcvd_idx) + self._rcvd_idx += 1 + return self._process_data(data, worker_id) + + assert not self._shutdown and self._tasks_outstanding > 0 + idx, data = self._get_data() + self._tasks_outstanding -= 1 + if self._dataset_kind == _DatasetKind.Iterable: + # Check for _IterableDatasetStopIteration + if isinstance(data, _utils.worker._IterableDatasetStopIteration): + if self._persistent_workers: + self._workers_status[data.worker_id] = False + else: + self._mark_worker_as_unavailable(data.worker_id) + self._try_put_index() + continue + + if idx != self._rcvd_idx: + if not self._in_order: + # don't store it for later, process now + # delete from self._task_info immediately + # this keeps the object size manageable + worker_id = self._task_info.pop(idx)[0] + return self._process_data(data, worker_id) + # store out-of-order samples + self._task_info[idx] += (data,) + else: + worker_id = self._task_info.pop(idx)[0] + self._rcvd_idx += 1 + return self._process_data(data, worker_id) + + def _try_put_index(self): + max_tasks = self._prefetch_factor * self._num_workers + assert self._tasks_outstanding < max_tasks + + try: + index = self._next_index() + except StopIteration: + return + for _ in range(self._num_workers): # find the next active worker, if any + worker_queue_idx = next(self._worker_queue_idx_cycle) + if self._workers_status[worker_queue_idx]: + if self._in_order: + break + elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum( + self._workers_status + ): + # when self._in_order is False, distribute work to a worker if it has capacity + # _workers_status is updated only in this thread, so the sum is guaranteed > 0 + break + else: + # not found (i.e., didn't break) + return + + self._index_queues[worker_queue_idx].put((self._send_idx, index)) # type: ignore[possibly-undefined] + self._task_info[self._send_idx] = (worker_queue_idx,) + self._workers_num_tasks[worker_queue_idx] += 1 + self._tasks_outstanding += 1 + self._send_idx += 1 + + def _process_data(self, data, worker_idx): + self._workers_num_tasks[worker_idx] -= 1 + self._try_put_index() + if isinstance(data, ExceptionWrapper): + data.reraise() + return data + + def _mark_worker_as_unavailable(self, worker_id, shutdown=False): + # Mark a worker as having finished its work e.g., due to + # exhausting an `IterableDataset`. This should be used only when this + # `_MultiProcessingDataLoaderIter` is going to continue running. + + assert self._workers_status[worker_id] or ( + self._persistent_workers and shutdown + ) + + # Signal termination to that specific worker. + q = self._index_queues[worker_id] + # Indicate that no more data will be put on this queue by the current + # process. + q.put(None) + + # Note that we don't actually join the worker here, nor do we remove the + # worker's pid from C side struct because (1) joining may be slow, and + # (2) since we don't join, the worker may still raise error, and we + # prefer capturing those, rather than ignoring them, even though they + # are raised after the worker has finished its job. + # Joinning is deferred to `_shutdown_workers`, which it is called when + # all workers finish their jobs (e.g., `IterableDataset` replicas) or + # when this iterator is garbage collected. + + self._workers_status[worker_id] = False + + assert self._workers_done_event.is_set() == shutdown + + def _shutdown_workers(self): + # Called when shutting down this `_MultiProcessingDataLoaderIter`. + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on + # the logic of this function. + if ( + _utils is None + or _utils.python_exit_status is True + or _utils.python_exit_status is None + ): + # See (2) of the note. If Python is shutting down, do no-op. + return + # Normal exit when last reference is gone / iterator is depleted. + # See (1) and the second half of the note. + if not self._shutdown: + self._shutdown = True + try: + # Normal exit when last reference is gone / iterator is depleted. + # See (1) and the second half of the note. + + # Exit `pin_memory_thread` first because exiting workers may leave + # corrupted data in `worker_result_queue` which `pin_memory_thread` + # reads from. + if hasattr(self, "_pin_memory_thread"): + # Use hasattr in case error happens before we set the attribute. + self._pin_memory_thread_done_event.set() + # Send something to pin_memory_thread in case it is waiting + # so that it can wake up and check `pin_memory_thread_done_event` + self._worker_result_queue.put((None, None)) + self._pin_memory_thread.join() + self._worker_result_queue.cancel_join_thread() + self._worker_result_queue.close() + + # Exit workers now. + self._workers_done_event.set() + for worker_id in range(len(self._workers)): + # Get number of workers from `len(self._workers)` instead of + # `self._num_workers` in case we error before starting all + # workers. + # If we are using workers_status with persistent_workers + # we have to shut it down because the worker is paused + if self._persistent_workers or self._workers_status[worker_id]: + self._mark_worker_as_unavailable(worker_id, shutdown=True) + for w in self._workers: + # We should be able to join here, but in case anything went + # wrong, we set a timeout and if the workers fail to join, + # they are killed in the `finally` block. + w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + for q in self._index_queues: + q.cancel_join_thread() + q.close() + finally: + # Even though all this function does is putting into queues that + # we have called `cancel_join_thread` on, weird things can + # happen when a worker is killed by a signal, e.g., hanging in + # `Event.set()`. So we need to guard this with SIGCHLD handler, + # and remove pids from the C side data structure only at the + # end. + # + # FIXME: Unfortunately, for Windows, we are missing a worker + # error detection mechanism here in this function, as it + # doesn't provide a SIGCHLD handler. + if self._worker_pids_set: + _utils.signal_handling._remove_worker_pids(id(self)) + self._worker_pids_set = False + for w in self._workers: + if w.is_alive(): + # Existing mechanisms try to make the workers exit + # peacefully, but in case that we unfortunately reach + # here, which we shouldn't, (e.g., pytorch/pytorch#39570), + # we kill the worker. + w.terminate() + + # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter` + @staticmethod + def _clean_up_worker(w): + try: + w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + finally: + if w.is_alive(): + w.terminate() + + def __del__(self): + self._shutdown_workers() diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/dataset.py b/.venv/lib/python3.12/site-packages/torch/utils/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d0234c553ce683741a03b4a5a73e226f492d2e44 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/data/dataset.py @@ -0,0 +1,485 @@ +# mypy: allow-untyped-defs +import bisect +import itertools +import math +import warnings +from collections.abc import Sequence + +# UP006 wants 'Iterable' to be imported from collections.abc but it needs to +# stay from typing for now due to BC concerns. In particular several internal +# targets fail to typecheck with: +# TypeError: Cannot create a consistent method resolution order (MRO) for +# bases Iterable, Generic +from typing import cast, Generic, Iterable, Optional, TypeVar, Union # noqa: UP035 +from typing_extensions import deprecated + +# No 'default_generator' in torch/__init__.pyi +from torch import default_generator, Generator, randperm, Tensor + + +__all__ = [ + "Dataset", + "IterableDataset", + "TensorDataset", + "StackDataset", + "ConcatDataset", + "ChainDataset", + "Subset", + "random_split", +] + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +_T_dict = dict[str, _T_co] +_T_tuple = tuple[_T_co, ...] +_T_stack = TypeVar("_T_stack", _T_tuple, _T_dict) + + +class Dataset(Generic[_T_co]): + r"""An abstract class representing a :class:`Dataset`. + + All datasets that represent a map from keys to data samples should subclass + it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a + data sample for a given key. Subclasses could also optionally overwrite + :meth:`__len__`, which is expected to return the size of the dataset by many + :class:`~torch.utils.data.Sampler` implementations and the default options + of :class:`~torch.utils.data.DataLoader`. Subclasses could also + optionally implement :meth:`__getitems__`, for speedup batched samples + loading. This method accepts list of indices of samples of batch and returns + list of samples. + + .. note:: + :class:`~torch.utils.data.DataLoader` by default constructs an index + sampler that yields integral indices. To make it work with a map-style + dataset with non-integral indices/keys, a custom sampler must be provided. + """ + + def __getitem__(self, index) -> _T_co: + raise NotImplementedError("Subclasses of Dataset should implement __getitem__.") + + # def __getitems__(self, indices: List) -> List[_T_co]: + # Not implemented to prevent false-positives in fetcher check in + # torch.utils.data._utils.fetch._MapDatasetFetcher + + def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]": + return ConcatDataset([self, other]) + + # No `def __len__(self)` default? + # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + # in pytorch/torch/utils/data/sampler.py + + +class IterableDataset(Dataset[_T_co], Iterable[_T_co]): + r"""An iterable Dataset. + + All datasets that represent an iterable of data samples should subclass it. + Such form of datasets is particularly useful when data come from a stream. + + All subclasses should overwrite :meth:`__iter__`, which would return an + iterator of samples in this dataset. + + When a subclass is used with :class:`~torch.utils.data.DataLoader`, each + item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader` + iterator. When :attr:`num_workers > 0`, each worker process will have a + different copy of the dataset object, so it is often desired to configure + each copy independently to avoid having duplicate data returned from the + workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker + process, returns information about the worker. It can be used in either the + dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's + :attr:`worker_init_fn` option to modify each copy's behavior. + + Example 1: splitting workload across all workers in :meth:`__iter__`:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) + >>> # xdoctest: +SKIP("Fails on MacOS12") + >>> class MyIterableDataset(torch.utils.data.IterableDataset): + ... def __init__(self, start, end): + ... super(MyIterableDataset).__init__() + ... assert end > start, "this example code only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... worker_info = torch.utils.data.get_worker_info() + ... if worker_info is None: # single-process data loading, return the full iterator + ... iter_start = self.start + ... iter_end = self.end + ... else: # in a worker process + ... # split workload + ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) + ... worker_id = worker_info.id + ... iter_start = self.start + worker_id * per_worker + ... iter_end = min(iter_start + per_worker, self.end) + ... return iter(range(iter_start, iter_end)) + ... + >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. + >>> ds = MyIterableDataset(start=3, end=7) + + >>> # Single-process loading + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) + [tensor([3]), tensor([4]), tensor([5]), tensor([6])] + + >>> # xdoctest: +REQUIRES(POSIX) + >>> # Multi-process loading with two worker processes + >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. + >>> # xdoctest: +IGNORE_WANT("non deterministic") + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) + [tensor([3]), tensor([5]), tensor([4]), tensor([6])] + + >>> # With even more workers + >>> # xdoctest: +IGNORE_WANT("non deterministic") + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12))) + [tensor([3]), tensor([5]), tensor([4]), tensor([6])] + + Example 2: splitting workload across all workers using :attr:`worker_init_fn`:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) + >>> class MyIterableDataset(torch.utils.data.IterableDataset): + ... def __init__(self, start, end): + ... super(MyIterableDataset).__init__() + ... assert end > start, "this example code only works with end >= start" + ... self.start = start + ... self.end = end + ... + ... def __iter__(self): + ... return iter(range(self.start, self.end)) + ... + >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. + >>> ds = MyIterableDataset(start=3, end=7) + + >>> # Single-process loading + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) + [3, 4, 5, 6] + >>> + >>> # Directly doing multi-process loading yields duplicate data + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) + [3, 3, 4, 4, 5, 5, 6, 6] + + >>> # Define a `worker_init_fn` that configures each dataset copy differently + >>> def worker_init_fn(worker_id): + ... worker_info = torch.utils.data.get_worker_info() + ... dataset = worker_info.dataset # the dataset copy in this worker process + ... overall_start = dataset.start + ... overall_end = dataset.end + ... # configure the dataset to only process the split workload + ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) + ... worker_id = worker_info.id + ... dataset.start = overall_start + worker_id * per_worker + ... dataset.end = min(dataset.start + per_worker, overall_end) + ... + + >>> # Mult-process loading with the custom `worker_init_fn` + >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) + [3, 5, 4, 6] + + >>> # With even more workers + >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn))) + [3, 4, 5, 6] + """ + + def __add__(self, other: Dataset[_T_co]): + return ChainDataset([self, other]) + + # No `def __len__(self)` default? Subclasses raise `TypeError` when needed. + # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + + +class TensorDataset(Dataset[tuple[Tensor, ...]]): + r"""Dataset wrapping tensors. + + Each sample will be retrieved by indexing tensors along the first dimension. + + Args: + *tensors (Tensor): tensors that have the same size of the first dimension. + """ + + tensors: tuple[Tensor, ...] + + def __init__(self, *tensors: Tensor) -> None: + assert all( + tensors[0].size(0) == tensor.size(0) for tensor in tensors + ), "Size mismatch between tensors" + self.tensors = tensors + + def __getitem__(self, index): + return tuple(tensor[index] for tensor in self.tensors) + + def __len__(self): + return self.tensors[0].size(0) + + +class StackDataset(Dataset[_T_stack]): + r"""Dataset as a stacking of multiple datasets. + + This class is useful to assemble different parts of complex input data, given as datasets. + + Example: + >>> # xdoctest: +SKIP + >>> images = ImageDataset() + >>> texts = TextDataset() + >>> tuple_stack = StackDataset(images, texts) + >>> tuple_stack[0] == (images[0], texts[0]) + >>> dict_stack = StackDataset(image=images, text=texts) + >>> dict_stack[0] == {'image': images[0], 'text': texts[0]} + + Args: + *args (Dataset): Datasets for stacking returned as tuple. + **kwargs (Dataset): Datasets for stacking returned as dict. + """ + + datasets: Union[tuple, dict] + + def __init__(self, *args: Dataset[_T_co], **kwargs: Dataset[_T_co]) -> None: + if args: + if kwargs: + raise ValueError( + "Supported either ``tuple``- (via ``args``) or" + "``dict``- (via ``kwargs``) like input/output, but both types are given." + ) + self._length = len(args[0]) # type: ignore[arg-type] + if any(self._length != len(dataset) for dataset in args): # type: ignore[arg-type] + raise ValueError("Size mismatch between datasets") + self.datasets = args + elif kwargs: + tmp = list(kwargs.values()) + self._length = len(tmp[0]) # type: ignore[arg-type] + if any(self._length != len(dataset) for dataset in tmp): # type: ignore[arg-type] + raise ValueError("Size mismatch between datasets") + self.datasets = kwargs + else: + raise ValueError("At least one dataset should be passed") + + def __getitem__(self, index): + if isinstance(self.datasets, dict): + return {k: dataset[index] for k, dataset in self.datasets.items()} + return tuple(dataset[index] for dataset in self.datasets) + + def __getitems__(self, indices: list): + # add batched sampling support when parent datasets supports it. + if isinstance(self.datasets, dict): + dict_batch: list[_T_dict] = [{} for _ in indices] + for k, dataset in self.datasets.items(): + if callable(getattr(dataset, "__getitems__", None)): + items = dataset.__getitems__(indices) # type: ignore[attr-defined] + if len(items) != len(indices): + raise ValueError( + "Nested dataset's output size mismatch." + f" Expected {len(indices)}, got {len(items)}" + ) + for data, d_sample in zip(items, dict_batch): + d_sample[k] = data + else: + for idx, d_sample in zip(indices, dict_batch): + d_sample[k] = dataset[idx] + return dict_batch + + # tuple data + list_batch: list[list] = [[] for _ in indices] + for dataset in self.datasets: + if callable(getattr(dataset, "__getitems__", None)): + items = dataset.__getitems__(indices) # type: ignore[attr-defined] + if len(items) != len(indices): + raise ValueError( + "Nested dataset's output size mismatch." + f" Expected {len(indices)}, got {len(items)}" + ) + for data, t_sample in zip(items, list_batch): + t_sample.append(data) + else: + for idx, t_sample in zip(indices, list_batch): + t_sample.append(dataset[idx]) + tuple_batch: list[_T_tuple] = [tuple(sample) for sample in list_batch] + return tuple_batch + + def __len__(self): + return self._length + + +class ConcatDataset(Dataset[_T_co]): + r"""Dataset as a concatenation of multiple datasets. + + This class is useful to assemble different existing datasets. + + Args: + datasets (sequence): List of datasets to be concatenated + """ + + datasets: list[Dataset[_T_co]] + cumulative_sizes: list[int] + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__() + self.datasets = list(datasets) + assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type] + for d in self.datasets: + assert not isinstance( + d, IterableDataset + ), "ConcatDataset does not support IterableDataset" + self.cumulative_sizes = self.cumsum(self.datasets) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError( + "absolute value of index should not exceed dataset length" + ) + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + @property + @deprecated( + "`cummulative_sizes` attribute is renamed to `cumulative_sizes`", + category=FutureWarning, + ) + def cummulative_sizes(self): + return self.cumulative_sizes + + +class ChainDataset(IterableDataset): + r"""Dataset for chaining multiple :class:`IterableDataset` s. + + This class is useful to assemble different existing dataset streams. The + chaining operation is done on-the-fly, so concatenating large-scale + datasets with this class will be efficient. + + Args: + datasets (iterable of IterableDataset): datasets to be chained together + """ + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__() + self.datasets = datasets + + def __iter__(self): + for d in self.datasets: + assert isinstance( + d, IterableDataset + ), "ChainDataset only supports IterableDataset" + yield from d + + def __len__(self): + total = 0 + for d in self.datasets: + assert isinstance( + d, IterableDataset + ), "ChainDataset only supports IterableDataset" + total += len(d) # type: ignore[arg-type] + return total + + +class Subset(Dataset[_T_co]): + r""" + Subset of a dataset at specified indices. + + Args: + dataset (Dataset): The whole Dataset + indices (sequence): Indices in the whole set selected for subset + """ + + dataset: Dataset[_T_co] + indices: Sequence[int] + + def __init__(self, dataset: Dataset[_T_co], indices: Sequence[int]) -> None: + self.dataset = dataset + self.indices = indices + + def __getitem__(self, idx): + if isinstance(idx, list): + return self.dataset[[self.indices[i] for i in idx]] + return self.dataset[self.indices[idx]] + + def __getitems__(self, indices: list[int]) -> list[_T_co]: + # add batched sampling support when parent dataset supports it. + # see torch.utils.data._utils.fetch._MapDatasetFetcher + if callable(getattr(self.dataset, "__getitems__", None)): + return self.dataset.__getitems__([self.indices[idx] for idx in indices]) # type: ignore[attr-defined] + else: + return [self.dataset[self.indices[idx]] for idx in indices] + + def __len__(self): + return len(self.indices) + + +def random_split( + dataset: Dataset[_T], + lengths: Sequence[Union[int, float]], + generator: Optional[Generator] = default_generator, +) -> list[Subset[_T]]: + r""" + Randomly split a dataset into non-overlapping new datasets of given lengths. + + If a list of fractions that sum up to 1 is given, + the lengths will be computed automatically as + floor(frac * len(dataset)) for each fraction provided. + + After computing the lengths, if there are any remainders, 1 count will be + distributed in round-robin fashion to the lengths + until there are no remainders left. + + Optionally fix the generator for reproducible results, e.g.: + + Example: + >>> # xdoctest: +SKIP + >>> generator1 = torch.Generator().manual_seed(42) + >>> generator2 = torch.Generator().manual_seed(42) + >>> random_split(range(10), [3, 7], generator=generator1) + >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2) + + Args: + dataset (Dataset): Dataset to be split + lengths (sequence): lengths or fractions of splits to be produced + generator (Generator): Generator used for the random permutation. + """ + if math.isclose(sum(lengths), 1) and sum(lengths) <= 1: + subset_lengths: list[int] = [] + for i, frac in enumerate(lengths): + if frac < 0 or frac > 1: + raise ValueError(f"Fraction at index {i} is not between 0 and 1") + n_items_in_split = int( + math.floor(len(dataset) * frac) # type: ignore[arg-type] + ) + subset_lengths.append(n_items_in_split) + remainder = len(dataset) - sum(subset_lengths) # type: ignore[arg-type] + # add 1 to all the lengths in round-robin fashion until the remainder is 0 + for i in range(remainder): + idx_to_add_at = i % len(subset_lengths) + subset_lengths[idx_to_add_at] += 1 + lengths = subset_lengths + for i, length in enumerate(lengths): + if length == 0: + warnings.warn( + f"Length of split at index {i} is 0. " + f"This might result in an empty dataset." + ) + + # Cannot verify that dataset is Sized + if sum(lengths) != len(dataset): # type: ignore[arg-type] + raise ValueError( + "Sum of input lengths does not equal the length of the input dataset!" + ) + + indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload] + lengths = cast(Sequence[int], lengths) + return [ + Subset(dataset, indices[offset - length : offset]) + for offset, length in zip(itertools.accumulate(lengths), lengths) + ] diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/distributed.py b/.venv/lib/python3.12/site-packages/torch/utils/data/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..949e3e0c23b409690ceb0dfa21f16b54c8493320 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/data/distributed.py @@ -0,0 +1,150 @@ +import math +from collections.abc import Iterator +from typing import Optional, TypeVar + +import torch +import torch.distributed as dist +from torch.utils.data.dataset import Dataset +from torch.utils.data.sampler import Sampler + + +__all__ = ["DistributedSampler"] + + +_T_co = TypeVar("_T_co", covariant=True) + + +class DistributedSampler(Sampler[_T_co]): + r"""Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each + process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a + :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the + original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size and that any instance of it always + returns the same elements in the same order. + + Args: + dataset: Dataset used for sampling. + num_replicas (int, optional): Number of processes participating in + distributed training. By default, :attr:`world_size` is retrieved from the + current distributed group. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. + + .. warning:: + In distributed mode, calling the :meth:`set_epoch` method at + the beginning of each epoch **before** creating the :class:`DataLoader` iterator + is necessary to make shuffling work properly across multiple epochs. Otherwise, + the same ordering will be always used. + + Example:: + + >>> # xdoctest: +SKIP + >>> sampler = DistributedSampler(dataset) if is_distributed else None + >>> loader = DataLoader(dataset, shuffle=(sampler is None), + ... sampler=sampler) + >>> for epoch in range(start_epoch, n_epochs): + ... if is_distributed: + ... sampler.set_epoch(epoch) + ... train(loader) + """ + + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]" + ) + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator[_T_co]: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + r""" + Set the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/graph.py b/.venv/lib/python3.12/site-packages/torch/utils/data/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..26a4eae6d18c32954d784310b5e87834332662ae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/data/graph.py @@ -0,0 +1,161 @@ +# mypy: allow-untyped-defs +import io +import pickle +import warnings +from collections.abc import Collection +from typing import Optional, Union + +from torch.utils._import_utils import dill_available +from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe + + +__all__ = ["traverse", "traverse_dps"] + +DataPipe = Union[IterDataPipe, MapDataPipe] +DataPipeGraph = dict[int, tuple[DataPipe, "DataPipeGraph"]] + + +def _stub_unpickler(): + return "STUB" + + +# TODO(VitalyFedyunin): Make sure it works without dill module installed +def _list_connected_datapipes( + scan_obj: DataPipe, only_datapipe: bool, cache: set[int] +) -> list[DataPipe]: + f = io.BytesIO() + p = pickle.Pickler( + f + ) # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is + if dill_available(): + from dill import Pickler as dill_Pickler + + d = dill_Pickler(f) + else: + d = None + + captured_connections = [] + + def getstate_hook(ori_state): + state = None + if isinstance(ori_state, dict): + state = {} + for k, v in ori_state.items(): + if isinstance(v, (IterDataPipe, MapDataPipe, Collection)): + state[k] = v + elif isinstance(ori_state, (tuple, list)): + state = [] # type: ignore[assignment] + for v in ori_state: + if isinstance(v, (IterDataPipe, MapDataPipe, Collection)): + state.append(v) # type: ignore[attr-defined] + elif isinstance(ori_state, (IterDataPipe, MapDataPipe, Collection)): + state = ori_state # type: ignore[assignment] + return state + + def reduce_hook(obj): + if obj == scan_obj or id(obj) in cache: + raise NotImplementedError + else: + captured_connections.append(obj) + # Adding id to remove duplicate DataPipe serialized at the same level + cache.add(id(obj)) + return _stub_unpickler, () + + datapipe_classes: tuple[type[DataPipe]] = (IterDataPipe, MapDataPipe) # type: ignore[assignment] + + try: + for cls in datapipe_classes: + cls.set_reduce_ex_hook(reduce_hook) + if only_datapipe: + cls.set_getstate_hook(getstate_hook) + try: + p.dump(scan_obj) + except (pickle.PickleError, AttributeError, TypeError): + if dill_available(): + d.dump(scan_obj) + else: + raise + finally: + for cls in datapipe_classes: + cls.set_reduce_ex_hook(None) + if only_datapipe: + cls.set_getstate_hook(None) + if dill_available(): + from dill import extend as dill_extend + + dill_extend(False) # Undo change to dispatch table + return captured_connections + + +def traverse_dps(datapipe: DataPipe) -> DataPipeGraph: + r""" + Traverse the DataPipes and their attributes to extract the DataPipe graph. + + This only looks into the attribute from each DataPipe that is either a + DataPipe and a Python collection object such as ``list``, ``tuple``, + ``set`` and ``dict``. + + Args: + datapipe: the end DataPipe of the graph + Returns: + A graph represented as a nested dictionary, where keys are ids of DataPipe instances + and values are tuples of DataPipe instance and the sub-graph + """ + cache: set[int] = set() + return _traverse_helper(datapipe, only_datapipe=True, cache=cache) + + +def traverse(datapipe: DataPipe, only_datapipe: Optional[bool] = None) -> DataPipeGraph: + r""" + Traverse the DataPipes and their attributes to extract the DataPipe graph. + + [Deprecated] + When ``only_dataPipe`` is specified as ``True``, it would only look into the + attribute from each DataPipe that is either a DataPipe and a Python collection object + such as ``list``, ``tuple``, ``set`` and ``dict``. + + Note: + This function is deprecated. Please use `traverse_dps` instead. + + Args: + datapipe: the end DataPipe of the graph + only_datapipe: If ``False`` (default), all attributes of each DataPipe are traversed. + This argument is deprecating and will be removed after the next release. + Returns: + A graph represented as a nested dictionary, where keys are ids of DataPipe instances + and values are tuples of DataPipe instance and the sub-graph + """ + msg = ( + "`traverse` function and will be removed after 1.13. " + "Please use `traverse_dps` instead." + ) + if not only_datapipe: + msg += " And, the behavior will be changed to the equivalent of `only_datapipe=True`." + warnings.warn(msg, FutureWarning) + if only_datapipe is None: + only_datapipe = False + cache: set[int] = set() + return _traverse_helper(datapipe, only_datapipe, cache) + + +# Add cache here to prevent infinite recursion on DataPipe +def _traverse_helper( + datapipe: DataPipe, only_datapipe: bool, cache: set[int] +) -> DataPipeGraph: + if not isinstance(datapipe, (IterDataPipe, MapDataPipe)): + raise RuntimeError( + f"Expected `IterDataPipe` or `MapDataPipe`, but {type(datapipe)} is found" + ) + + dp_id = id(datapipe) + if dp_id in cache: + return {} + cache.add(dp_id) + # Using cache.copy() here is to prevent the same DataPipe pollutes the cache on different paths + items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy()) + d: DataPipeGraph = {dp_id: (datapipe, {})} + for item in items: + # Using cache.copy() here is to prevent recursion on a single path rather than global graph + # Single DataPipe can present multiple times in different paths in graph + d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy())) + return d diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/graph_settings.py b/.venv/lib/python3.12/site-packages/torch/utils/data/graph_settings.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc16c86b0f3d27ab05dae069bc3ed7a87b297ab --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/data/graph_settings.py @@ -0,0 +1,174 @@ +# mypy: allow-untyped-defs +import inspect +import warnings +from typing import Any, Optional +from typing_extensions import deprecated + +import torch +from torch.utils.data.datapipes.iter.sharding import ( + _ShardingIterDataPipe, + SHARDING_PRIORITIES, +) +from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps + + +__all__ = [ + "apply_random_seed", + "apply_sharding", + "apply_shuffle_seed", + "apply_shuffle_settings", + "get_all_graph_pipes", +] + + +def get_all_graph_pipes(graph: DataPipeGraph) -> list[DataPipe]: + return _get_all_graph_pipes_helper(graph, set()) + + +def _get_all_graph_pipes_helper( + graph: DataPipeGraph, id_cache: set[int] +) -> list[DataPipe]: + results: list[DataPipe] = [] + for dp_id, (datapipe, sub_graph) in graph.items(): + if dp_id in id_cache: + continue + id_cache.add(dp_id) + results.append(datapipe) + results.extend(_get_all_graph_pipes_helper(sub_graph, id_cache)) + return results + + +def _is_sharding_datapipe(datapipe: DataPipe) -> bool: + return isinstance(datapipe, _ShardingIterDataPipe) or ( + hasattr(datapipe, "apply_sharding") + and inspect.ismethod(datapipe.apply_sharding) + ) + + +def apply_sharding( + datapipe: DataPipe, + num_of_instances: int, + instance_id: int, + sharding_group=SHARDING_PRIORITIES.DEFAULT, +) -> DataPipe: + r""" + Apply dynamic sharding over the ``sharding_filter`` DataPipe that has a method ``apply_sharding``. + + RuntimeError will be raised when multiple ``sharding_filter`` are presented in the same branch. + """ + graph = traverse_dps(datapipe) + + def _helper(graph, prev_applied=None): + for dp, sub_graph in graph.values(): + applied = None + if _is_sharding_datapipe(dp): + if prev_applied is not None: + raise RuntimeError( + "Sharding twice on a single pipeline is likely unintended and will cause data loss. " + f"Sharding already applied to {prev_applied} while trying to apply to {dp}" + ) + # For BC, only provide sharding_group if accepted + sig = inspect.signature(dp.apply_sharding) + if len(sig.parameters) < 3: + dp.apply_sharding(num_of_instances, instance_id) + else: + dp.apply_sharding( + num_of_instances, instance_id, sharding_group=sharding_group + ) + applied = dp + if applied is None: + applied = prev_applied + _helper(sub_graph, applied) + + _helper(graph) + + return datapipe + + +def _is_shuffle_datapipe(datapipe: DataPipe) -> bool: + return ( + hasattr(datapipe, "set_shuffle") + and hasattr(datapipe, "set_seed") + and inspect.ismethod(datapipe.set_shuffle) + and inspect.ismethod(datapipe.set_seed) + ) + + +def apply_shuffle_settings( + datapipe: DataPipe, shuffle: Optional[bool] = None +) -> DataPipe: + r""" + Traverse the graph of ``DataPipes`` to find and set shuffle attribute. + + Apply the method to each `DataPipe` that has APIs of ``set_shuffle`` + and ``set_seed``. + + Args: + datapipe: DataPipe that needs to set shuffle attribute + shuffle: Shuffle option (default: ``None`` and no-op to the graph) + """ + if shuffle is None: + return datapipe + + graph = traverse_dps(datapipe) + all_pipes = get_all_graph_pipes(graph) + shufflers = [pipe for pipe in all_pipes if _is_shuffle_datapipe(pipe)] + if not shufflers and shuffle: + warnings.warn( + "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. " + "Be aware that the default buffer size might not be sufficient for your task." + ) + datapipe = datapipe.shuffle() + shufflers = [ + datapipe, + ] + + for shuffler in shufflers: + shuffler.set_shuffle(shuffle) + + return datapipe + + +@deprecated( + "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in the future releases. " + "Please use `apply_random_seed` instead.", + category=FutureWarning, +) +def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe: + return apply_random_seed(datapipe, rng) + + +def _is_random_datapipe(datapipe: DataPipe) -> bool: + return hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed) + + +def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe: + r""" + Traverse the graph of ``DataPipes`` to find random ``DataPipe`` with an API of ``set_seed``. + + Then set the random seed based on the provided RNG to those ``DataPipe``. + + Args: + datapipe: DataPipe that needs to set randomness + rng: Random number generator to generate random seeds + """ + graph = traverse_dps(datapipe) + all_pipes = get_all_graph_pipes(graph) + # Using a set to track id of DataPipe to prevent setting randomness per DataPipe more than once. + # And, `id` is used in case of unhashable DataPipe + cache = set() + random_datapipes = [] + for pipe in all_pipes: + if id(pipe) in cache: + continue + if _is_random_datapipe(pipe): + random_datapipes.append(pipe) + cache.add(id(pipe)) + + for pipe in random_datapipes: + random_seed = int( + torch.empty((), dtype=torch.int64).random_(generator=rng).item() + ) + pipe.set_seed(random_seed) + + return datapipe diff --git a/.venv/lib/python3.12/site-packages/torch/utils/data/sampler.py b/.venv/lib/python3.12/site-packages/torch/utils/data/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..e998d545ac2315d4e39678293a69ed0f195c081e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/data/sampler.py @@ -0,0 +1,348 @@ +# mypy: allow-untyped-defs +import itertools +from collections.abc import Iterable, Iterator, Sequence, Sized +from typing import Generic, Optional, TypeVar, Union + +import torch + + +__all__ = [ + "BatchSampler", + "RandomSampler", + "Sampler", + "SequentialSampler", + "SubsetRandomSampler", + "WeightedRandomSampler", +] + + +_T_co = TypeVar("_T_co", covariant=True) + + +class Sampler(Generic[_T_co]): + r"""Base class for all Samplers. + + Every Sampler subclass has to provide an :meth:`__iter__` method, providing a + way to iterate over indices or lists of indices (batches) of dataset elements, + and may provide a :meth:`__len__` method that returns the length of the returned iterators. + + Args: + data_source (Dataset): This argument is not used and will be removed in 2.2.0. + You may still have custom implementation that utilizes it. + + Example: + >>> # xdoctest: +SKIP + >>> class AccedingSequenceLengthSampler(Sampler[int]): + >>> def __init__(self, data: List[str]) -> None: + >>> self.data = data + >>> + >>> def __len__(self) -> int: + >>> return len(self.data) + >>> + >>> def __iter__(self) -> Iterator[int]: + >>> sizes = torch.tensor([len(x) for x in self.data]) + >>> yield from torch.argsort(sizes).tolist() + >>> + >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]): + >>> def __init__(self, data: List[str], batch_size: int) -> None: + >>> self.data = data + >>> self.batch_size = batch_size + >>> + >>> def __len__(self) -> int: + >>> return (len(self.data) + self.batch_size - 1) // self.batch_size + >>> + >>> def __iter__(self) -> Iterator[List[int]]: + >>> sizes = torch.tensor([len(x) for x in self.data]) + >>> for batch in torch.chunk(torch.argsort(sizes), len(self)): + >>> yield batch.tolist() + + .. note:: The :meth:`__len__` method isn't strictly required by + :class:`~torch.utils.data.DataLoader`, but is expected in any + calculation involving the length of a :class:`~torch.utils.data.DataLoader`. + """ + + def __init__(self, data_source: Optional[Sized] = None) -> None: + if data_source is not None: + import warnings + + warnings.warn( + "`data_source` argument is not used and will be removed in 2.2.0." + "You may still have custom implementation that utilizes it." + ) + + def __iter__(self) -> Iterator[_T_co]: + raise NotImplementedError + + # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + # + # Many times we have an abstract class representing a collection/iterable of + # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally + # implementing a `__len__` method. In such cases, we must make sure to not + # provide a default implementation, because both straightforward default + # implementations have their issues: + # + # + `return NotImplemented`: + # Calling `len(subclass_instance)` raises: + # TypeError: 'NotImplementedType' object cannot be interpreted as an integer + # + # + `raise NotImplementedError`: + # This prevents triggering some fallback behavior. E.g., the built-in + # `list(X)` tries to call `len(X)` first, and executes a different code + # path if the method is not found or `NotImplemented` is returned, while + # raising a `NotImplementedError` will propagate and make the call fail + # where it could have used `__iter__` to complete the call. + # + # Thus, the only two sensible things to do are + # + # + **not** provide a default `__len__`. + # + # + raise a `TypeError` instead, which is what Python uses when users call + # a method that is not defined on an object. + # (@ssnl verifies that this works on at least Python 3.7.) + + +class SequentialSampler(Sampler[int]): + r"""Samples elements sequentially, always in the same order. + + Args: + data_source (Dataset): dataset to sample from + """ + + data_source: Sized + + def __init__(self, data_source: Sized) -> None: + self.data_source = data_source + + def __iter__(self) -> Iterator[int]: + return iter(range(len(self.data_source))) + + def __len__(self) -> int: + return len(self.data_source) + + +class RandomSampler(Sampler[int]): + r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. + + If with replacement, then user can specify :attr:`num_samples` to draw. + + Args: + data_source (Dataset): dataset to sample from + replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` + num_samples (int): number of samples to draw, default=`len(dataset)`. + generator (Generator): Generator used in sampling. + """ + + data_source: Sized + replacement: bool + + def __init__( + self, + data_source: Sized, + replacement: bool = False, + num_samples: Optional[int] = None, + generator=None, + ) -> None: + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.generator = generator + + if not isinstance(self.replacement, bool): + raise TypeError( + f"replacement should be a boolean value, but got replacement={self.replacement}" + ) + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError( + f"num_samples should be a positive integer value, but got num_samples={self.num_samples}" + ) + + @property + def num_samples(self) -> int: + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self) -> Iterator[int]: + n = len(self.data_source) + if self.generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = self.generator + + if self.replacement: + for _ in range(self.num_samples // 32): + yield from torch.randint( + high=n, size=(32,), dtype=torch.int64, generator=generator + ).tolist() + yield from torch.randint( + high=n, + size=(self.num_samples % 32,), + dtype=torch.int64, + generator=generator, + ).tolist() + else: + for _ in range(self.num_samples // n): + yield from torch.randperm(n, generator=generator).tolist() + yield from torch.randperm(n, generator=generator).tolist()[ + : self.num_samples % n + ] + + def __len__(self) -> int: + return self.num_samples + + +class SubsetRandomSampler(Sampler[int]): + r"""Samples elements randomly from a given list of indices, without replacement. + + Args: + indices (sequence): a sequence of indices + generator (Generator): Generator used in sampling. + """ + + indices: Sequence[int] + + def __init__(self, indices: Sequence[int], generator=None) -> None: + self.indices = indices + self.generator = generator + + def __iter__(self) -> Iterator[int]: + for i in torch.randperm(len(self.indices), generator=self.generator).tolist(): + yield self.indices[i] + + def __len__(self) -> int: + return len(self.indices) + + +class WeightedRandomSampler(Sampler[int]): + r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights). + + Args: + weights (sequence) : a sequence of weights, not necessary summing up to one + num_samples (int): number of samples to draw + replacement (bool): if ``True``, samples are drawn with replacement. + If not, they are drawn without replacement, which means that when a + sample index is drawn for a row, it cannot be drawn again for that row. + generator (Generator): Generator used in sampling. + + Example: + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) + [4, 4, 1, 4, 5] + >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) + [0, 1, 4, 3, 2] + """ + + weights: torch.Tensor + num_samples: int + replacement: bool + + def __init__( + self, + weights: Sequence[float], + num_samples: int, + replacement: bool = True, + generator=None, + ) -> None: + if ( + not isinstance(num_samples, int) + or isinstance(num_samples, bool) + or num_samples <= 0 + ): + raise ValueError( + f"num_samples should be a positive integer value, but got num_samples={num_samples}" + ) + if not isinstance(replacement, bool): + raise ValueError( + f"replacement should be a boolean value, but got replacement={replacement}" + ) + + weights_tensor = torch.as_tensor(weights, dtype=torch.double) + if len(weights_tensor.shape) != 1: + raise ValueError( + "weights should be a 1d sequence but given " + f"weights have shape {tuple(weights_tensor.shape)}" + ) + + self.weights = weights_tensor + self.num_samples = num_samples + self.replacement = replacement + self.generator = generator + + def __iter__(self) -> Iterator[int]: + rand_tensor = torch.multinomial( + self.weights, self.num_samples, self.replacement, generator=self.generator + ) + yield from iter(rand_tensor.tolist()) + + def __len__(self) -> int: + return self.num_samples + + +class BatchSampler(Sampler[list[int]]): + r"""Wraps another sampler to yield a mini-batch of indices. + + Args: + sampler (Sampler or Iterable): Base sampler. Can be any iterable object + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size`` + + Example: + >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + """ + + def __init__( + self, + sampler: Union[Sampler[int], Iterable[int]], + batch_size: int, + drop_last: bool, + ) -> None: + # Since collections.abc.Iterable does not check for `__getitem__`, which + # is one way for an object to be an iterable, we don't do an `isinstance` + # check here. + if ( + not isinstance(batch_size, int) + or isinstance(batch_size, bool) + or batch_size <= 0 + ): + raise ValueError( + f"batch_size should be a positive integer value, but got batch_size={batch_size}" + ) + if not isinstance(drop_last, bool): + raise ValueError( + f"drop_last should be a boolean value, but got drop_last={drop_last}" + ) + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + + def __iter__(self) -> Iterator[list[int]]: + # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 + sampler_iter = iter(self.sampler) + if self.drop_last: + # Create multiple references to the same iterator + args = [sampler_iter] * self.batch_size + for batch_droplast in zip(*args): + yield [*batch_droplast] + else: + batch = [*itertools.islice(sampler_iter, self.batch_size)] + while batch: + yield batch + batch = [*itertools.islice(sampler_iter, self.batch_size)] + + def __len__(self) -> int: + # Can only be called if self.sampler has __len__ implemented + # We cannot enforce this condition, so we turn off typechecking for the + # implementation below. + # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] + if self.drop_last: + return len(self.sampler) // self.batch_size # type: ignore[arg-type] + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore[arg-type] diff --git a/.venv/lib/python3.12/site-packages/torch/utils/hipify/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/hipify/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58f3ace6c03d093337c9fa417ccbe8bc267b6c69 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/hipify/__init__.py @@ -0,0 +1 @@ +from .version import __version__ diff --git a/.venv/lib/python3.12/site-packages/torch/utils/hipify/constants.py b/.venv/lib/python3.12/site-packages/torch/utils/hipify/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a9053b261ad44d1ef8b8cbdf3a27da0306d92f36 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/hipify/constants.py @@ -0,0 +1,62 @@ +"""Constants for annotations in the mapping. + +The constants defined here are used to annotate the mapping tuples in cuda_to_hip_mappings.py. +They are based on +https://github.com/ROCm/HIPIFY/blob/master/src/Statistics.h +and fall in three categories: 1) type of mapping, 2) API of mapping, 3) unsupported +mapping. +""" + +CONV_VERSION = 0, +CONV_INIT = 1 +CONV_DEVICE = 2 +CONV_MEM = 3 +CONV_KERN = 4 +CONV_COORD_FUNC = 5 +CONV_MATH_FUNC = 6 +CONV_DEVICE_FUNC = 7 +CONV_SPECIAL_FUNC = 8 +CONV_STREAM = 9 +CONV_EVENT = 10 +CONV_OCCUPANCY = 11 +CONV_CONTEXT = 12 +CONV_PEER = 13 +CONV_MODULE = 14 +CONV_CACHE = 15 +CONV_EXEC = 16 +CONV_ERROR = 17 +CONV_DEF = 18 +CONV_TEX = 19 +CONV_GL = 20 +CONV_GRAPHICS = 21 +CONV_SURFACE = 22 +CONV_JIT = 23 +CONV_D3D9 = 24 +CONV_D3D10 = 25 +CONV_D3D11 = 26 +CONV_VDPAU = 27 +CONV_EGL = 28 +CONV_THREAD = 29 +CONV_OTHER = 30 +CONV_INCLUDE = 31 +CONV_INCLUDE_CUDA_MAIN_H = 32 +CONV_TYPE = 33 +CONV_LITERAL = 34 +CONV_NUMERIC_LITERAL = 35 +CONV_LAST = 36 + +API_DRIVER = 37 +API_RUNTIME = 38 +API_BLAS = 39 +API_SPECIAL = 40 +API_RAND = 41 +API_LAST = 42 +API_FFT = 43 +API_RTC = 44 +API_ROCTX = 45 + +HIP_UNSUPPORTED = 46 +API_PYTORCH = 1337 +API_CAFFE2 = 1338 +API_C10 = 1339 +API_ROCMSMI = 1340 diff --git a/.venv/lib/python3.12/site-packages/torch/utils/hipify/cuda_to_hip_mappings.py b/.venv/lib/python3.12/site-packages/torch/utils/hipify/cuda_to_hip_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..a5145a2f4870a13fb6f6b5f21239615c29a51531 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/hipify/cuda_to_hip_mappings.py @@ -0,0 +1,8821 @@ +import collections +import os + +from .constants import (API_BLAS, API_C10, API_CAFFE2, API_DRIVER, API_FFT, + API_PYTORCH, API_RAND, API_ROCTX, API_RTC, API_RUNTIME, + API_SPECIAL, API_ROCMSMI, CONV_CACHE, CONV_CONTEXT, CONV_D3D9, + CONV_D3D10, CONV_D3D11, CONV_DEF, CONV_DEVICE, + CONV_DEVICE_FUNC, CONV_EGL, CONV_ERROR, CONV_EVENT, + CONV_EXEC, CONV_GL, CONV_GRAPHICS, CONV_INCLUDE, + CONV_INCLUDE_CUDA_MAIN_H, CONV_INIT, CONV_JIT, + CONV_MATH_FUNC, CONV_MEM, CONV_MODULE, + CONV_NUMERIC_LITERAL, CONV_OCCUPANCY, CONV_OTHER, + CONV_PEER, CONV_SPECIAL_FUNC, CONV_STREAM, + CONV_SURFACE, CONV_TEX, CONV_THREAD, CONV_TYPE, + CONV_VDPAU, CONV_VERSION, HIP_UNSUPPORTED) + +""" Mapping of CUDA functions, include files, constants, and types to ROCm/HIP equivalents +This closely follows the implementation in hipify-clang +https://github.com/ROCm/hip/blob/59071b895ed1c86d9698b4c859cefcdd5acda06f/hipify-clang/src/CUDA2HipMap.cpp +and its structure. +There are different maps for fundamental names, include files, identifies, sparse, and +PyTorch specific translations. +Each of the entries in these maps translates a CUDA string to a tuple containing the +ROCm/HIP string, a type and API annotation and - optionally - an annotation if it is not +supported in ROCm/HIP yet. +""" + +_IS_FBCODE = os.environ.get("IS_FBCODE", "0") == "1" + +# FBCODE compiles against rccl sources instead of an installed rccl package. +# The header location is src/rccl.h versus rccl/rccl.h, respectively. +_RCCL_HEADER = "" if _IS_FBCODE else "" + +# List of math functions that should be replaced inside device code only. +MATH_TRANSPILATIONS = collections.OrderedDict( + [ + ("std::max", ("::max")), + ("std::min", ("::min")), + ("std::ceil", ("::ceil")), + ("std::floor", ("::floor")), + ("std::exp", ("::exp")), + ("std::log", ("::log")), + ("std::pow", ("::pow")), + ("std::fabs", ("::fabs")), + ("std::fmod", ("::fmod")), + ("std::remainder", ("::remainder")), + ("std::frexp", ("::frexp")), + ] +) + +CUDA_TYPE_NAME_MAP = collections.OrderedDict( + [ + ("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)), + ("cudaError_t", ("hipError_t", CONV_TYPE, API_RUNTIME)), + ("cudaError", ("hipError_t", CONV_TYPE, API_RUNTIME)), + ( + "CUDA_ARRAY3D_DESCRIPTOR", + ("HIP_ARRAY3D_DESCRIPTOR", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUDA_ARRAY_DESCRIPTOR", ("HIP_ARRAY_DESCRIPTOR", CONV_TYPE, API_DRIVER)), + ("CUDA_MEMCPY2D", ("hip_Memcpy2D", CONV_TYPE, API_DRIVER)), + ("CUDA_MEMCPY3D", ("HIP_MEMCPY3D", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUDA_MEMCPY3D_PEER", + ("HIP_MEMCPY3D_PEER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_POINTER_ATTRIBUTE_P2P_TOKENS", + ( + "HIP_POINTER_ATTRIBUTE_P2P_TOKENS", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CUDA_RESOURCE_DESC", + ("HIP_RESOURCE_DESC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_RESOURCE_VIEW_DESC", + ("HIP_RESOURCE_VIEW_DESC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUipcEventHandle", + ("hipIpcEventHandle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUipcMemHandle", ("hipIpcMemHandle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUaddress_mode", ("hipAddress_mode", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUarray_cubemap_face", + ("hipArray_cubemap_face", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUarray_format", ("hipArray_format", CONV_TYPE, API_DRIVER)), + ("CUcomputemode", ("hipComputemode", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUmem_advise", ("hipMemAdvise", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUmem_range_attribute", + ("hipMemRangeAttribute", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUctx_flags", ("hipCctx_flags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUdevice", ("hipDevice_t", CONV_TYPE, API_DRIVER)), + ("CUdevice_attribute_enum", ("hipDeviceAttribute_t", CONV_TYPE, API_DRIVER)), + ("CUdevice_attribute", ("hipDeviceAttribute_t", CONV_TYPE, API_DRIVER)), + ("CUpointer_attribute", ("hipPointer_attribute", CONV_TYPE, API_DRIVER)), + ("CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL", ("HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL", CONV_TYPE, API_DRIVER)), + ("CU_POINTER_ATTRIBUTE_BUFFER_ID", ("HIP_POINTER_ATTRIBUTE_BUFFER_ID", CONV_TYPE, API_DRIVER)), + ("CUdeviceptr", ("hipDeviceptr_t", CONV_TYPE, API_DRIVER)), + ("CUarray_st", ("hipArray", CONV_TYPE, API_DRIVER)), + ("CUarray", ("hipArray *", CONV_TYPE, API_DRIVER)), + ("CUdevprop_st", ("hipDeviceProp_t", CONV_TYPE, API_DRIVER)), + ("CUdevprop", ("hipDeviceProp_t", CONV_TYPE, API_DRIVER)), + ("CUfunction", ("hipFunction_t", CONV_TYPE, API_DRIVER)), + ( + "CUgraphicsResource", + ("hipGraphicsResource_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUmipmappedArray", + ("hipMipmappedArray_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUfunction_attribute", + ("hipFuncAttribute_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUfunction_attribute_enum", + ("hipFuncAttribute_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsMapResourceFlags", + ("hipGraphicsMapFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsMapResourceFlags_enum", + ("hipGraphicsMapFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsRegisterFlags", + ("hipGraphicsRegisterFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsRegisterFlags_enum", + ("hipGraphicsRegisterFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUoccupancy_flags", + ("hipOccupancyFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUoccupancy_flags_enum", + ("hipOccupancyFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUfunc_cache_enum", ("hipFuncCache", CONV_TYPE, API_DRIVER)), + ("CUfunc_cache", ("hipFuncCache", CONV_TYPE, API_DRIVER)), + ("CUipcMem_flags", ("hipIpcMemFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUipcMem_flags_enum", + ("hipIpcMemFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUjit_cacheMode", ("hipJitCacheMode", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUjit_cacheMode_enum", + ("hipJitCacheMode", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUjit_fallback", ("hipJitFallback", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUjit_fallback_enum", + ("hipJitFallback", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUjit_option", ("hipJitOption", CONV_JIT, API_DRIVER)), + ("CUjit_option_enum", ("hipJitOption", CONV_JIT, API_DRIVER)), + ("CUjit_target", ("hipJitTarget", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ("CUjit_target_enum", ("hipJitTarget", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ("CUjitInputType", ("hipJitInputType", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUjitInputType_enum", + ("hipJitInputType", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUlimit", ("hipLimit_t", CONV_TYPE, API_DRIVER)), + ("CUlimit_enum", ("hipLimit_t", CONV_TYPE, API_DRIVER)), + ( + "CUmemAttach_flags", + ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUmemAttach_flags_enum", + ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUmemorytype", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUmemorytype_enum", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUresourcetype", ("hipResourceType", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUresourcetype_enum", + ("hipResourceType", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUresourceViewFormat", ("hipResourceViewFormat", CONV_TEX, API_DRIVER)), + ("CUresourceViewFormat_enum", ("hipResourceViewFormat", CONV_TEX, API_DRIVER)), + ("CUsharedconfig", ("hipSharedMemConfig", CONV_TYPE, API_DRIVER)), + ("CUsharedconfig_enum", ("hipSharedMemConfig", CONV_TYPE, API_DRIVER)), + ("CUcontext", ("hipCtx_t", CONV_TYPE, API_DRIVER)), + ("CUmodule", ("hipModule_t", CONV_TYPE, API_DRIVER)), + ("CUstream", ("hipStream_t", CONV_TYPE, API_DRIVER)), + ("CUstream_st", ("ihipStream_t", CONV_TYPE, API_DRIVER)), + ("CUstreamCallback", ("hipStreamCallback_t", CONV_TYPE, API_DRIVER)), + ("CUsurfObject", ("hipSurfaceObject", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUsurfref", + ("hipSurfaceReference_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUtexObject", ("hipTextureObject_t", CONV_TYPE, API_DRIVER)), + ("CUtexref", ("textureReference", CONV_TYPE, API_DRIVER)), + ("CUstream_flags", ("hipStreamFlags", CONV_TYPE, API_DRIVER)), + ( + "CUstreamWaitValue_flags", + ("hipStreamWaitValueFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUstreamWriteValue_flags", + ("hipStreamWriteValueFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUstreamBatchMemOpType", + ("hipStreamBatchMemOpType", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUdevice_P2PAttribute", + ("hipDeviceP2PAttribute", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUevent", ("hipEvent_t", CONV_TYPE, API_DRIVER)), + ("CUevent_st", ("ihipEvent_t", CONV_TYPE, API_DRIVER)), + ("CUevent_flags", ("hipEventFlags", CONV_EVENT, API_DRIVER, HIP_UNSUPPORTED)), + ("CUfilter_mode", ("hipTextureFilterMode", CONV_TEX, API_DRIVER)), + ("CUGLDeviceList", ("hipGLDeviceList", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ("CUGLmap_flags", ("hipGLMapFlags", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUd3d9DeviceList", + ("hipD3D9DeviceList", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d9map_flags", + ("hipD3D9MapFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d9register_flags", + ("hipD3D9RegisterFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d10DeviceList", + ("hipd3d10DeviceList", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d10map_flags", + ("hipD3D10MapFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d10register_flags", + ("hipD3D10RegisterFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d11DeviceList", + ("hipd3d11DeviceList", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUeglStreamConnection_st", + ("hipEglStreamConnection", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUeglStreamConnection", + ("hipEglStreamConnection", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "libraryPropertyType_t", + ("hipLibraryPropertyType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "libraryPropertyType", + ("hipLibraryPropertyType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaStreamCallback_t", ("hipStreamCallback_t", CONV_TYPE, API_RUNTIME)), + ("cudaArray", ("hipArray", CONV_MEM, API_RUNTIME)), + ("cudaArray_t", ("hipArray_t", CONV_MEM, API_RUNTIME)), + ("cudaArray_const_t", ("hipArray_const_t", CONV_MEM, API_RUNTIME)), + ("cudaMipmappedArray_t", ("hipMipmappedArray_t", CONV_MEM, API_RUNTIME)), + ( + "cudaMipmappedArray_const_t", + ("hipMipmappedArray_const_t", CONV_MEM, API_RUNTIME), + ), + ("cudaArrayDefault", ("hipArrayDefault", CONV_MEM, API_RUNTIME)), + ("cudaArrayLayered", ("hipArrayLayered", CONV_MEM, API_RUNTIME)), + ( + "cudaArraySurfaceLoadStore", + ("hipArraySurfaceLoadStore", CONV_MEM, API_RUNTIME), + ), + ("cudaArrayCubemap", ("hipArrayCubemap", CONV_MEM, API_RUNTIME)), + ("cudaArrayTextureGather", ("hipArrayTextureGather", CONV_MEM, API_RUNTIME)), + ("cudaMemoryAdvise", ("hipMemoryAdvise", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaMemRangeAttribute", + ("hipMemRangeAttribute", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpyKind", ("hipMemcpyKind", CONV_MEM, API_RUNTIME)), + ("cudaMemoryType", ("hipMemoryType", CONV_MEM, API_RUNTIME)), + ("cudaExtent", ("hipExtent", CONV_MEM, API_RUNTIME)), + ("cudaPitchedPtr", ("hipPitchedPtr", CONV_MEM, API_RUNTIME)), + ("cudaPos", ("hipPos", CONV_MEM, API_RUNTIME)), + ("cudaEvent_t", ("hipEvent_t", CONV_TYPE, API_RUNTIME)), + ("cudaStream_t", ("hipStream_t", CONV_TYPE, API_RUNTIME)), + ("cudaPointerAttributes", ("hipPointerAttribute_t", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceAttr", ("hipDeviceAttribute_t", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceProp", ("hipDeviceProp_t", CONV_TYPE, API_RUNTIME)), + ( + "cudaDeviceP2PAttr", + ("hipDeviceP2PAttribute", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeMode", + ("hipComputeMode", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaFuncCache", ("hipFuncCache_t", CONV_CACHE, API_RUNTIME)), + ( + "cudaFuncAttributes", + ("hipFuncAttributes", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaSharedMemConfig", ("hipSharedMemConfig", CONV_TYPE, API_RUNTIME)), + ("cudaLimit", ("hipLimit_t", CONV_TYPE, API_RUNTIME)), + ("cudaOutputMode", ("hipOutputMode", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED)), + ("cudaTextureReadMode", ("hipTextureReadMode", CONV_TEX, API_RUNTIME)), + ("cudaTextureFilterMode", ("hipTextureFilterMode", CONV_TEX, API_RUNTIME)), + ("cudaChannelFormatKind", ("hipChannelFormatKind", CONV_TEX, API_RUNTIME)), + ("cudaChannelFormatDesc", ("hipChannelFormatDesc", CONV_TEX, API_RUNTIME)), + ("cudaResourceDesc", ("hipResourceDesc", CONV_TEX, API_RUNTIME)), + ("cudaResourceViewDesc", ("hipResourceViewDesc", CONV_TEX, API_RUNTIME)), + ("cudaTextureDesc", ("hipTextureDesc", CONV_TEX, API_RUNTIME)), + ( + "surfaceReference", + ("hipSurfaceReference", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaTextureObject_t", ("hipTextureObject_t", CONV_TEX, API_RUNTIME)), + ("cudaResourceType", ("hipResourceType", CONV_TEX, API_RUNTIME)), + ("cudaResourceViewFormat", ("hipResourceViewFormat", CONV_TEX, API_RUNTIME)), + ("cudaTextureAddressMode", ("hipTextureAddressMode", CONV_TEX, API_RUNTIME)), + ( + "cudaSurfaceBoundaryMode", + ("hipSurfaceBoundaryMode", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaSurfaceFormatMode", + ("hipSurfaceFormatMode", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaTextureType1D", ("hipTextureType1D", CONV_TEX, API_RUNTIME)), + ("cudaTextureType2D", ("hipTextureType2D", CONV_TEX, API_RUNTIME)), + ("cudaTextureType3D", ("hipTextureType3D", CONV_TEX, API_RUNTIME)), + ("cudaTextureTypeCubemap", ("hipTextureTypeCubemap", CONV_TEX, API_RUNTIME)), + ( + "cudaTextureType1DLayered", + ("hipTextureType1DLayered", CONV_TEX, API_RUNTIME), + ), + ( + "cudaTextureType2DLayered", + ("hipTextureType2DLayered", CONV_TEX, API_RUNTIME), + ), + ( + "cudaTextureTypeCubemapLayered", + ("hipTextureTypeCubemapLayered", CONV_TEX, API_RUNTIME), + ), + ("cudaIpcEventHandle_t", ("hipIpcEventHandle_t", CONV_TYPE, API_RUNTIME)), + ("cudaIpcEventHandle_st", ("hipIpcEventHandle_t", CONV_TYPE, API_RUNTIME)), + ("cudaIpcMemHandle_t", ("hipIpcMemHandle_t", CONV_TYPE, API_RUNTIME)), + ("cudaIpcMemHandle_st", ("hipIpcMemHandle_t", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphicsCubeFace", + ("hipGraphicsCubeFace", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsMapFlags", + ("hipGraphicsMapFlags", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsRegisterFlags", + ("hipGraphicsRegisterFlags", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLDeviceList", + ("hipGLDeviceList", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaGLMapFlags", ("hipGLMapFlags", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaD3D9DeviceList", + ("hipD3D9DeviceList", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9MapFlags", + ("hipD3D9MapFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9RegisterFlags", + ("hipD3D9RegisterFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10DeviceList", + ("hipd3d10DeviceList", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10MapFlags", + ("hipD3D10MapFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10RegisterFlags", + ("hipD3D10RegisterFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11DeviceList", + ("hipd3d11DeviceList", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEglStreamConnection", + ("hipEglStreamConnection", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cublasHandle_t", ("hipblasHandle_t", CONV_TYPE, API_BLAS)), + ("cublasOperation_t", ("hipblasOperation_t", CONV_TYPE, API_BLAS)), + ("cublasStatus_t", ("hipblasStatus_t", CONV_TYPE, API_BLAS)), + ("cublasFillMode_t", ("hipblasFillMode_t", CONV_TYPE, API_BLAS)), + ("cublasDiagType_t", ("hipblasDiagType_t", CONV_TYPE, API_BLAS)), + ("cublasSideMode_t", ("hipblasSideMode_t", CONV_TYPE, API_BLAS)), + ("cublasPointerMode_t", ("hipblasPointerMode_t", CONV_TYPE, API_BLAS)), + ("cublasGemmAlgo_t", ("hipblasGemmAlgo_t", CONV_TYPE, API_BLAS)), + ( + "cublasAtomicsMode_t", + ("hipblasAtomicsMode_t", CONV_TYPE, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDataType_t", + ("hipblasDatatype_t", CONV_TYPE, API_BLAS, HIP_UNSUPPORTED), + ), + ("curandStatus", ("hiprandStatus_t", CONV_TYPE, API_RAND)), + ("curandStatus_t", ("hiprandStatus_t", CONV_TYPE, API_RAND)), + ("curandRngType", ("hiprandRngType_t", CONV_TYPE, API_RAND)), + ("curandRngType_t", ("hiprandRngType_t", CONV_TYPE, API_RAND)), + ("curandGenerator_st", ("hiprandGenerator_st", CONV_TYPE, API_RAND)), + ("curandGenerator_t", ("hiprandGenerator_t", CONV_TYPE, API_RAND)), + ( + "curandDirectionVectorSet", + ("hiprandDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDirectionVectorSet_t", + ("hiprandDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ("curandOrdering", ("hiprandOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), + ( + "curandOrdering_t", + ("hiprandOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistribution_st", + ("hiprandDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2V_st", + ("hiprandDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistribution_t", + ("hiprandDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2V_t", + ("hiprandDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionShift_st", + ("hiprandDistributionShift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionShift_t", + ("hiprandDistributionShift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionM2Shift_st", + ("hiprandDistributionM2Shift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionM2Shift_t", + ("hiprandDistributionM2Shift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2_st", + ("hiprandHistogramM2_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2_t", + ("hiprandHistogramM2_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2K_st", + ("hiprandHistogramM2K_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2K_t", + ("hiprandHistogramM2K_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDiscreteDistribution_st", + ("hiprandDiscreteDistribution_st", CONV_TYPE, API_RAND), + ), + ( + "curandDiscreteDistribution_t", + ("hiprandDiscreteDistribution_t", CONV_TYPE, API_RAND), + ), + ("curandMethod", ("hiprandMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), + ("curandMethod_t", ("hiprandMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), + ( + "curandDirectionVectors32_t", + ("hiprandDirectionVectors32_t", CONV_TYPE, API_RAND), + ), + ( + "curandDirectionVectors64_t", + ("hiprandDirectionVectors64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ("curandStateMtgp32_t", ("hiprandStateMtgp32_t", CONV_TYPE, API_RAND)), + ("curandStateMtgp32", ("hiprandStateMtgp32_t", CONV_TYPE, API_RAND)), + ( + "curandStateScrambledSobol64_t", + ("hiprandStateScrambledSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandStateSobol64_t", + ("hiprandStateSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandStateScrambledSobol32_t", + ("hiprandStateScrambledSobol32_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ("curandStateSobol32_t", ("hiprandStateSobol32_t", CONV_TYPE, API_RAND)), + ("curandStateMRG32k3a_t", ("hiprandStateMRG32k3a_t", CONV_TYPE, API_RAND)), + ( + "curandStatePhilox4_32_10_t", + ("hiprandStatePhilox4_32_10_t", CONV_TYPE, API_RAND), + ), + ("curandStateXORWOW_t", ("hiprandStateXORWOW_t", CONV_TYPE, API_RAND)), + ("curandState_t", ("hiprandState_t", CONV_TYPE, API_RAND)), + ("curandState", ("hiprandState_t", CONV_TYPE, API_RAND)), + ("CUuuid", ("hipUUID", CONV_TYPE, API_RUNTIME)), + ("cudaGraph_t", ("hipGraph_t", CONV_TYPE, API_RAND)), + ("cudaGraphExec_t", ("hipGraphExec_t", CONV_TYPE, API_RAND)), + ("__nv_bfloat16", ("__hip_bfloat16", CONV_TYPE, API_RUNTIME)), + ("__nv_bfloat162", ("__hip_bfloat162", CONV_TYPE, API_RUNTIME)), + ] +) + +CUDA_INCLUDE_MAP = collections.OrderedDict( + [ + # since pytorch uses "\b{pattern}\b" as the actual re pattern, + # patterns listed here have to begin and end with alnum chars + ( + "include " to differentiate + ("", (_RCCL_HEADER, CONV_INCLUDE, API_RUNTIME)), + ("nvrtc.h", ("hip/hiprtc.h", CONV_INCLUDE, API_RTC)), + ("thrust/system/cuda", ("thrust/system/hip", CONV_INCLUDE, API_BLAS)), + ("cub/util_allocator.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_reduce.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_raking_layout.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/cub.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/config.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/util_ptx.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/util_type.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_run_length_encode.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_load.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_store.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_scan.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_radix_sort.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_reduce.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_scan.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_select.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("nvtx3/nvtx3.hpp", ("roctracer/roctx.h", CONV_INCLUDE, API_ROCTX)), + ("nvml.h", ("rocm_smi/rocm_smi.h", CONV_INCLUDE, API_ROCMSMI)), + ] +) + +CUDA_IDENTIFIER_MAP = collections.OrderedDict( + [ + ("__CUDACC__", ("__HIPCC__", CONV_DEF, API_RUNTIME)), + ( + "CUDA_ERROR_INVALID_CONTEXT", + ("hipErrorInvalidContext", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_CONTEXT_ALREADY_CURRENT", + ("hipErrorContextAlreadyCurrent", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_ARRAY_IS_MAPPED", + ("hipErrorArrayIsMapped", CONV_TYPE, API_DRIVER), + ), + ("CUDA_ERROR_ALREADY_MAPPED", ("hipErrorAlreadyMapped", CONV_TYPE, API_DRIVER)), + ( + "CUDA_ERROR_ALREADY_ACQUIRED", + ("hipErrorAlreadyAcquired", CONV_TYPE, API_DRIVER), + ), + ("CUDA_ERROR_NOT_MAPPED", ("hipErrorNotMapped", CONV_TYPE, API_DRIVER)), + ( + "CUDA_ERROR_NOT_MAPPED_AS_ARRAY", + ("hipErrorNotMappedAsArray", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_NOT_MAPPED_AS_POINTER", + ("hipErrorNotMappedAsPointer", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_CONTEXT_ALREADY_IN_USE", + ("hipErrorContextAlreadyInUse", CONV_TYPE, API_DRIVER), + ), + ("CUDA_ERROR_INVALID_SOURCE", ("hipErrorInvalidSource", CONV_TYPE, API_DRIVER)), + ("CUDA_ERROR_FILE_NOT_FOUND", ("hipErrorFileNotFound", CONV_TYPE, API_DRIVER)), + ("CUDA_ERROR_NOT_FOUND", ("hipErrorNotFound", CONV_TYPE, API_DRIVER)), + ( + "CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING", + ( + "hipErrorLaunchIncompatibleTexturing", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE", + ("hipErrorPrimaryContextActive", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_CONTEXT_IS_DESTROYED", + ("hipErrorContextIsDestroyed", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NOT_PERMITTED", + ("hipErrorNotPermitted", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NOT_SUPPORTED", + ("hipErrorNotSupported", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMissingConfiguration", + ("hipErrorMissingConfiguration", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorPriorLaunchFailure", + ("hipErrorPriorLaunchFailure", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidDeviceFunction", + ("hipErrorInvalidDeviceFunction", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidConfiguration", + ("hipErrorInvalidConfiguration", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidPitchValue", + ("hipErrorInvalidPitchValue", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidSymbol", + ("hipErrorInvalidSymbol", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidHostPointer", + ("hipErrorInvalidHostPointer", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidDevicePointer", + ("hipErrorInvalidDevicePointer", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaErrorInvalidTexture", + ("hipErrorInvalidTexture", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidTextureBinding", + ("hipErrorInvalidTextureBinding", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidChannelDescriptor", + ( + "hipErrorInvalidChannelDescriptor", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaErrorInvalidMemcpyDirection", + ("hipErrorInvalidMemcpyDirection", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorAddressOfConstant", + ("hipErrorAddressOfConstant", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorTextureFetchFailed", + ("hipErrorTextureFetchFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorTextureNotBound", + ("hipErrorTextureNotBound", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorSynchronizationError", + ("hipErrorSynchronizationError", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidFilterSetting", + ("hipErrorInvalidFilterSetting", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidNormSetting", + ("hipErrorInvalidNormSetting", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMixedDeviceExecution", + ("hipErrorMixedDeviceExecution", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorNotYetImplemented", + ("hipErrorNotYetImplemented", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMemoryValueTooLarge", + ("hipErrorMemoryValueTooLarge", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInsufficientDriver", + ("hipErrorInsufficientDriver", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorSetOnActiveProcess", + ("hipErrorSetOnActiveProcess", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorContextIsDestroyed", + ("hipErrorContextIsDestroyed", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaErrorInvalidSurface", + ("hipErrorInvalidSurface", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDuplicateVariableName", + ("hipErrorDuplicateVariableName", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDuplicateTextureName", + ("hipErrorDuplicateTextureName", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDuplicateSurfaceName", + ("hipErrorDuplicateSurfaceName", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDevicesUnavailable", + ("hipErrorDevicesUnavailable", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorIncompatibleDriverContext", + ( + "hipErrorIncompatibleDriverContext", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaErrorDeviceAlreadyInUse", + ("hipErrorDeviceAlreadyInUse", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchMaxDepthExceeded", + ("hipErrorLaunchMaxDepthExceeded", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchFileScopedTex", + ("hipErrorLaunchFileScopedTex", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchFileScopedSurf", + ("hipErrorLaunchFileScopedSurf", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorSyncDepthExceeded", + ("hipErrorSyncDepthExceeded", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchPendingCountExceeded", + ( + "hipErrorLaunchPendingCountExceeded", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaErrorNotPermitted", + ("hipErrorNotPermitted", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorNotSupported", + ("hipErrorNotSupported", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorStartupFailure", + ("hipErrorStartupFailure", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorApiFailureBase", + ("hipErrorApiFailureBase", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_SUCCESS", ("hipSuccess", CONV_TYPE, API_DRIVER)), + ("cudaSuccess", ("hipSuccess", CONV_TYPE, API_RUNTIME)), + ("CUDA_ERROR_INVALID_VALUE", ("hipErrorInvalidValue", CONV_TYPE, API_DRIVER)), + ("cudaErrorInvalidValue", ("hipErrorInvalidValue", CONV_TYPE, API_RUNTIME)), + ( + "CUDA_ERROR_OUT_OF_MEMORY", + ("hipErrorMemoryAllocation", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorMemoryAllocation", + ("hipErrorMemoryAllocation", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_NOT_INITIALIZED", + ("hipErrorNotInitialized", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInitializationError", + ("hipErrorInitializationError", CONV_TYPE, API_RUNTIME), + ), + ("CUDA_ERROR_DEINITIALIZED", ("hipErrorDeinitialized", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorCudartUnloading", + ("hipErrorDeinitialized", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_DISABLED", + ("hipErrorProfilerDisabled", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerDisabled", + ("hipErrorProfilerDisabled", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_NOT_INITIALIZED", + ("hipErrorProfilerNotInitialized", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerNotInitialized", + ("hipErrorProfilerNotInitialized", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_ALREADY_STARTED", + ("hipErrorProfilerAlreadyStarted", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerAlreadyStarted", + ("hipErrorProfilerAlreadyStarted", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_ALREADY_STOPPED", + ("hipErrorProfilerAlreadyStopped", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerAlreadyStopped", + ("hipErrorProfilerAlreadyStopped", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_ERROR_NO_DEVICE", ("hipErrorNoDevice", CONV_TYPE, API_DRIVER)), + ("cudaErrorNoDevice", ("hipErrorNoDevice", CONV_TYPE, API_RUNTIME)), + ("CUDA_ERROR_INVALID_DEVICE", ("hipErrorInvalidDevice", CONV_TYPE, API_DRIVER)), + ("cudaErrorInvalidDevice", ("hipErrorInvalidDevice", CONV_TYPE, API_RUNTIME)), + ("CUDA_ERROR_INVALID_IMAGE", ("hipErrorInvalidImage", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorInvalidKernelImage", + ("hipErrorInvalidImage", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_ERROR_MAP_FAILED", ("hipErrorMapFailed", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorMapBufferObjectFailed", + ("hipErrorMapFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_ERROR_UNMAP_FAILED", ("hipErrorUnmapFailed", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorUnmapBufferObjectFailed", + ("hipErrorUnmapFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NO_BINARY_FOR_GPU", + ("hipErrorNoBinaryForGpu", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorNoKernelImageForDevice", + ("hipErrorNoBinaryForGpu", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_ECC_UNCORRECTABLE", + ("hipErrorECCNotCorrectable", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorECCUncorrectable", + ("hipErrorECCNotCorrectable", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_UNSUPPORTED_LIMIT", + ("hipErrorUnsupportedLimit", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorUnsupportedLimit", + ("hipErrorUnsupportedLimit", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PEER_ACCESS_UNSUPPORTED", + ("hipErrorPeerAccessUnsupported", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorPeerAccessUnsupported", + ("hipErrorPeerAccessUnsupported", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_PTX", + ("hipErrorInvalidKernelFile", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInvalidPtx", + ("hipErrorInvalidKernelFile", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_GRAPHICS_CONTEXT", + ("hipErrorInvalidGraphicsContext", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInvalidGraphicsContext", + ("hipErrorInvalidGraphicsContext", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NVLINK_UNCORRECTABLE", + ("hipErrorNvlinkUncorrectable", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorNvlinkUncorrectable", + ("hipErrorNvlinkUncorrectable", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND", + ("hipErrorSharedObjectSymbolNotFound", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorSharedObjectSymbolNotFound", + ( + "hipErrorSharedObjectSymbolNotFound", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "CUDA_ERROR_SHARED_OBJECT_INIT_FAILED", + ("hipErrorSharedObjectInitFailed", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorSharedObjectInitFailed", + ("hipErrorSharedObjectInitFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_OPERATING_SYSTEM", + ("hipErrorOperatingSystem", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorOperatingSystem", + ("hipErrorOperatingSystem", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_HANDLE", + ("hipErrorInvalidResourceHandle", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInvalidResourceHandle", + ("hipErrorInvalidResourceHandle", CONV_TYPE, API_RUNTIME), + ), + ("CUDA_ERROR_NOT_READY", ("hipErrorNotReady", CONV_TYPE, API_DRIVER)), + ("cudaErrorNotReady", ("hipErrorNotReady", CONV_TYPE, API_RUNTIME)), + ( + "CUDA_ERROR_ILLEGAL_ADDRESS", + ("hipErrorIllegalAddress", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorIllegalAddress", + ("hipErrorIllegalAddress", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES", + ("hipErrorLaunchOutOfResources", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorLaunchOutOfResources", + ("hipErrorLaunchOutOfResources", CONV_TYPE, API_RUNTIME), + ), + ("CUDA_ERROR_LAUNCH_TIMEOUT", ("hipErrorLaunchTimeOut", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorLaunchTimeout", + ("hipErrorLaunchTimeOut", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED", + ("hipErrorPeerAccessAlreadyEnabled", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorPeerAccessAlreadyEnabled", + ("hipErrorPeerAccessAlreadyEnabled", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_PEER_ACCESS_NOT_ENABLED", + ("hipErrorPeerAccessNotEnabled", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorPeerAccessNotEnabled", + ("hipErrorPeerAccessNotEnabled", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_ASSERT", + ("hipErrorAssert", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorAssert", + ("hipErrorAssert", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_TOO_MANY_PEERS", + ("hipErrorTooManyPeers", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorTooManyPeers", + ("hipErrorTooManyPeers", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED", + ("hipErrorHostMemoryAlreadyRegistered", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorHostMemoryAlreadyRegistered", + ("hipErrorHostMemoryAlreadyRegistered", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED", + ("hipErrorHostMemoryNotRegistered", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorHostMemoryNotRegistered", + ("hipErrorHostMemoryNotRegistered", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_HARDWARE_STACK_ERROR", + ("hipErrorHardwareStackError", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorHardwareStackError", + ("hipErrorHardwareStackError", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_ILLEGAL_INSTRUCTION", + ("hipErrorIllegalInstruction", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorIllegalInstruction", + ("hipErrorIllegalInstruction", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_MISALIGNED_ADDRESS", + ("hipErrorMisalignedAddress", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMisalignedAddress", + ("hipErrorMisalignedAddress", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_ADDRESS_SPACE", + ("hipErrorInvalidAddressSpace", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidAddressSpace", + ("hipErrorInvalidAddressSpace", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_PC", + ("hipErrorInvalidPc", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidPc", + ("hipErrorInvalidPc", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_LAUNCH_FAILED", + ("hipErrorLaunchFailure", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchFailure", + ("hipErrorLaunchFailure", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_UNKNOWN", + ("hipErrorUnknown", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cudaErrorUnknown", ("hipErrorUnknown", CONV_TYPE, API_RUNTIME)), + ( + "CU_TR_ADDRESS_MODE_WRAP", + ("HIP_TR_ADDRESS_MODE_WRAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TR_ADDRESS_MODE_CLAMP", + ("HIP_TR_ADDRESS_MODE_CLAMP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TR_ADDRESS_MODE_MIRROR", + ("HIP_TR_ADDRESS_MODE_MIRROR", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TR_ADDRESS_MODE_BORDER", + ("HIP_TR_ADDRESS_MODE_BORDER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_POSITIVE_X", + ("HIP_CUBEMAP_FACE_POSITIVE_X", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_NEGATIVE_X", + ("HIP_CUBEMAP_FACE_NEGATIVE_X", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_POSITIVE_Y", + ("HIP_CUBEMAP_FACE_POSITIVE_Y", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_NEGATIVE_Y", + ("HIP_CUBEMAP_FACE_NEGATIVE_Y", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_POSITIVE_Z", + ("HIP_CUBEMAP_FACE_POSITIVE_Z", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_NEGATIVE_Z", + ("HIP_CUBEMAP_FACE_NEGATIVE_Z", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_AD_FORMAT_UNSIGNED_INT8", + ("HIP_AD_FORMAT_UNSIGNED_INT8", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_UNSIGNED_INT16", + ("HIP_AD_FORMAT_UNSIGNED_INT16", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_UNSIGNED_INT32", + ("HIP_AD_FORMAT_UNSIGNED_INT32", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_SIGNED_INT8", + ("HIP_AD_FORMAT_SIGNED_INT8", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_SIGNED_INT16", + ("HIP_AD_FORMAT_SIGNED_INT16", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_SIGNED_INT32", + ("HIP_AD_FORMAT_SIGNED_INT32", CONV_TYPE, API_DRIVER), + ), + ("CU_AD_FORMAT_HALF", ("HIP_AD_FORMAT_HALF", CONV_TYPE, API_DRIVER)), + ("CU_AD_FORMAT_FLOAT", ("HIP_AD_FORMAT_FLOAT", CONV_TYPE, API_DRIVER)), + ( + "CU_COMPUTEMODE_DEFAULT", + ("hipComputeModeDefault", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_COMPUTEMODE_EXCLUSIVE", + ("hipComputeModeExclusive", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_COMPUTEMODE_PROHIBITED", + ("hipComputeModeProhibited", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_COMPUTEMODE_EXCLUSIVE_PROCESS", + ("hipComputeModeExclusiveProcess", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_SET_READ_MOSTLY", + ("hipMemAdviseSetReadMostly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_UNSET_READ_MOSTLY", + ("hipMemAdviseUnsetReadMostly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_SET_PREFERRED_LOCATION", + ( + "hipMemAdviseSetPreferredLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION", + ( + "hipMemAdviseUnsetPreferredLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_MEM_ADVISE_SET_ACCESSED_BY", + ("hipMemAdviseSetAccessedBy", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_UNSET_ACCESSED_BY", + ("hipMemAdviseUnsetAccessedBy", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY", + ("hipMemRangeAttributeReadMostly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION", + ( + "hipMemRangeAttributePreferredLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY", + ("hipMemRangeAttributeAccessedBy", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION", + ( + "hipMemRangeAttributeLastPrefetchLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_CTX_SCHED_AUTO", + ("HIP_CTX_SCHED_AUTO", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_SPIN", + ("HIP_CTX_SCHED_SPIN", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_YIELD", + ("HIP_CTX_SCHED_YIELD", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_BLOCKING_SYNC", + ("HIP_CTX_SCHED_BLOCKING_SYNC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_BLOCKING_SYNC", + ("HIP_CTX_BLOCKING_SYNC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_MASK", + ("HIP_CTX_SCHED_MASK", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_MAP_HOST", + ("HIP_CTX_MAP_HOST", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_LMEM_RESIZE_TO_MAX", + ("HIP_CTX_LMEM_RESIZE_TO_MAX", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_FLAGS_MASK", + ("HIP_CTX_FLAGS_MASK", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LAUNCH_PARAM_BUFFER_POINTER", + ("HIP_LAUNCH_PARAM_BUFFER_POINTER", CONV_TYPE, API_DRIVER), + ), + ( + "CU_LAUNCH_PARAM_BUFFER_SIZE", + ("HIP_LAUNCH_PARAM_BUFFER_SIZE", CONV_TYPE, API_DRIVER), + ), + ("CU_LAUNCH_PARAM_END", ("HIP_LAUNCH_PARAM_END", CONV_TYPE, API_DRIVER)), + ( + "CU_IPC_HANDLE_SIZE", + ("HIP_IPC_HANDLE_SIZE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTALLOC_DEVICEMAP", + ("HIP_MEMHOSTALLOC_DEVICEMAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTALLOC_PORTABLE", + ("HIP_MEMHOSTALLOC_PORTABLE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTALLOC_WRITECOMBINED", + ("HIP_MEMHOSTALLOC_WRITECOMBINED", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTREGISTER_DEVICEMAP", + ("HIP_MEMHOSTREGISTER_DEVICEMAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTREGISTER_IOMEMORY", + ("HIP_MEMHOSTREGISTER_IOMEMORY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTREGISTER_PORTABLE", + ("HIP_MEMHOSTREGISTER_PORTABLE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_PARAM_TR_DEFAULT", + ("HIP_PARAM_TR_DEFAULT", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_LEGACY", + ("HIP_STREAM_LEGACY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_PER_THREAD", + ("HIP_STREAM_PER_THREAD", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TRSA_OVERRIDE_FORMAT", + ("HIP_TRSA_OVERRIDE_FORMAT", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TRSF_NORMALIZED_COORDINATES", + ("HIP_TRSF_NORMALIZED_COORDINATES", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TRSF_READ_AS_INTEGER", + ("HIP_TRSF_READ_AS_INTEGER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_TRSF_SRGB", ("HIP_TRSF_SRGB", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUDA_ARRAY3D_2DARRAY", + ("HIP_ARRAY3D_LAYERED", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_CUBEMAP", + ("HIP_ARRAY3D_CUBEMAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_DEPTH_TEXTURE", + ("HIP_ARRAY3D_DEPTH_TEXTURE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_LAYERED", + ("HIP_ARRAY3D_LAYERED", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_SURFACE_LDST", + ("HIP_ARRAY3D_SURFACE_LDST", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_TEXTURE_GATHER", + ("HIP_ARRAY3D_TEXTURE_GATHER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK", + ( + "hipDeviceAttributeMaxThreadsPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X", + ("hipDeviceAttributeMaxBlockDimX", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y", + ("hipDeviceAttributeMaxBlockDimY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z", + ("hipDeviceAttributeMaxBlockDimZ", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X", + ("hipDeviceAttributeMaxGridDimX", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y", + ("hipDeviceAttributeMaxGridDimY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z", + ("hipDeviceAttributeMaxGridDimZ", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK", + ( + "hipDeviceAttributeMaxSharedMemoryPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_SHARED_MEMORY_PER_BLOCK", + ( + "hipDeviceAttributeMaxSharedMemoryPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY", + ( + "hipDeviceAttributeTotalConstantMemory", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_WARP_SIZE", + ("hipDeviceAttributeWarpSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_PITCH", + ("hipDeviceAttributeMaxPitch", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK", + ( + "hipDeviceAttributeMaxRegistersPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_REGISTERS_PER_BLOCK", + ( + "hipDeviceAttributeMaxRegistersPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CLOCK_RATE", + ("hipDeviceAttributeClockRate", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT", + ( + "hipDeviceAttributeTextureAlignment", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_GPU_OVERLAP", + ( + "hipDeviceAttributeAsyncEngineCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT", + ( + "hipDeviceAttributeMultiprocessorCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT", + ( + "hipDeviceAttributeKernelExecTimeout", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_INTEGRATED", + ("hipDeviceAttributeIntegrated", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY", + ( + "hipDeviceAttributeCanMapHostMemory", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_MODE", + ("hipDeviceAttributeComputeMode", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH", + ( + "hipDeviceAttributeMaxTexture3DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT", + ( + "hipDeviceAttributeMaxTexture3DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH", + ( + "hipDeviceAttributeMaxTexture3DDepth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DLayeredHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxTexture2DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DLayeredHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_NUMSLICES", + ( + "hipDeviceAttributeMaxTexture2DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_SURFACE_ALIGNMENT", + ( + "hipDeviceAttributeSurfaceAlignment", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS", + ("hipDeviceAttributeConcurrentKernels", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_ECC_ENABLED", + ("hipDeviceAttributeEccEnabled", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_PCI_BUS_ID", + ("hipDeviceAttributePciBusId", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID", + ("hipDeviceAttributePciDeviceId", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_TCC_DRIVER", + ("hipDeviceAttributeTccDriver", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE", + ( + "hipDeviceAttributeMemoryClockRate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH", + ("hipDeviceAttributeMemoryBusWidth", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE", + ("hipDeviceAttributeL2CacheSize", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR", + ("hipDeviceAttributeMaxThreadsPerMultiProcessor", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT", + ( + "hipDeviceAttributeAsyncEngineCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING", + ( + "hipDeviceAttributeUnifiedAddressing", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxTexture1DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CAN_TEX2D_GATHER", + ( + "hipDeviceAttributeCanTex2DGather", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DGatherWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DGatherHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE", + ( + "hipDeviceAttributeMaxTexture3DWidthAlternate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE", + ( + "hipDeviceAttributeMaxTexture3DHeightAlternate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE", + ( + "hipDeviceAttributeMaxTexture3DDepthAlternate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID", + ("hipDeviceAttributePciDomainId", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT", + ( + "hipDeviceAttributeTexturePitchAlignment", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH", + ( + "hipDeviceAttributeMaxTextureCubemapWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH", + ( + "hipDeviceAttributeMaxSurface1DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH", + ( + "hipDeviceAttributeMaxSurface2DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT", + ( + "hipDeviceAttributeMaxSurface2DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH", + ( + "hipDeviceAttributeMaxSurface3DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT", + ( + "hipDeviceAttributeMaxSurface3DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH", + ( + "hipDeviceAttributeMaxSurface3DDepth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxSurface1DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxSurface1DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxSurface2DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT", + ( + "hipDeviceAttributeMaxSurface2DLayeredHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxSurface2DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH", + ( + "hipDeviceAttributeMaxSurfaceCubemapWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DLinearWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DLinearWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DLinearHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH", + ( + "hipDeviceAttributeMaxTexture2DLinearPitch", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DMipmappedWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DMipmappedHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR", + ("hipDeviceAttributeComputeCapabilityMajor", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR", + ("hipDeviceAttributeComputeCapabilityMinor", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DMipmappedWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_STREAM_PRIORITIES_SUPPORTED", + ( + "hipDeviceAttributeStreamPrioritiesSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED", + ( + "hipDeviceAttributeGlobalL1CacheSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED", + ( + "hipDeviceAttributeLocalL1CacheSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR", + ( + "hipDeviceAttributeMaxSharedMemoryPerMultiprocessor", + CONV_TYPE, + API_DRIVER, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR", + ( + "hipDeviceAttributeMaxRegistersPerMultiprocessor", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY", + ("hipDeviceAttributeManagedMemory", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD", + ("hipDeviceAttributeIsMultiGpuBoard", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD_GROUP_ID", + ( + "hipDeviceAttributeMultiGpuBoardGroupId", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_HOST_NATIVE_ATOMIC_SUPPORTED", + ( + "hipDeviceAttributeHostNativeAtomicSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO", + ( + "hipDeviceAttributeSingleToDoublePrecisionPerfRatio", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS", + ( + "hipDeviceAttributePageableMemoryAccess", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS", + ( + "hipDeviceAttributeConcurrentManagedAccess", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED", + ( + "hipDeviceAttributeComputePreemptionSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM", + ( + "hipDeviceAttributeCanUseHostPointerForRegisteredMem", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX", + ("hipDeviceAttributeMax", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_CONTEXT", + ("hipPointerAttributeContext", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_MEMORY_TYPE", + ("hipPointerAttributeMemoryType", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_DEVICE_POINTER", + ( + "hipPointerAttributeDevicePointer", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_POINTER_ATTRIBUTE_HOST_POINTER", + ("hipPointerAttributeHostPointer", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_P2P_TOKENS", + ("hipPointerAttributeP2pTokens", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_SYNC_MEMOPS", + ("hipPointerAttributeSyncMemops", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_BUFFER_ID", + ("hipPointerAttributeBufferId", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_IS_MANAGED", + ("hipPointerAttributeIsManaged", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK", + ( + "hipFuncAttributeMaxThreadsPerBlocks", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES", + ("hipFuncAttributeSharedSizeBytes", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES", + ("hipFuncAttributeMaxDynamicSharedMemorySize", CONV_TYPE, API_RUNTIME), + ), + ( + "CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES", + ("hipFuncAttributeConstSizeBytes", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES", + ("hipFuncAttributeLocalSizeBytes", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_NUM_REGS", + ("hipFuncAttributeNumRegs", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_PTX_VERSION", + ("hipFuncAttributePtxVersion", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_BINARY_VERSION", + ("hipFuncAttributeBinaryVersion", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_CACHE_MODE_CA", + ("hipFuncAttributeCacheModeCA", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_MAX", + ("hipFuncAttributeMax", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_MAP_RESOURCE_FLAGS_NONE", + ("hipGraphicsMapFlagsNone", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_MAP_RESOURCE_FLAGS_READ_ONLY", + ("hipGraphicsMapFlagsReadOnly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + ("hipGraphicsMapFlagsWriteDiscard", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_NONE", + ("hipGraphicsRegisterFlagsNone", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_READ_ONLY", + ( + "hipGraphicsRegisterFlagsReadOnly", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD", + ( + "hipGraphicsRegisterFlagsWriteDiscard", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_SURFACE_LDST", + ( + "hipGraphicsRegisterFlagsSurfaceLoadStore", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_TEXTURE_GATHER", + ( + "hipGraphicsRegisterFlagsTextureGather", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_OCCUPANCY_DEFAULT", + ("hipOccupancyDefault", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE", + ( + "hipOccupancyDisableCachingOverride", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_FUNC_CACHE_PREFER_NONE", + ("hipFuncCachePreferNone", CONV_CACHE, API_DRIVER), + ), + ( + "CU_FUNC_CACHE_PREFER_SHARED", + ("hipFuncCachePreferShared", CONV_CACHE, API_DRIVER), + ), + ("CU_FUNC_CACHE_PREFER_L1", ("hipFuncCachePreferL1", CONV_CACHE, API_DRIVER)), + ( + "CU_FUNC_CACHE_PREFER_EQUAL", + ("hipFuncCachePreferEqual", CONV_CACHE, API_DRIVER), + ), + ( + "CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS", + ("hipIpcMemLazyEnablePeerAccess", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUDA_IPC_HANDLE_SIZE", ("HIP_IPC_HANDLE_SIZE", CONV_TYPE, API_DRIVER)), + ( + "CU_JIT_CACHE_OPTION_NONE", + ("hipJitCacheModeOptionNone", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_CACHE_OPTION_CG", + ("hipJitCacheModeOptionCG", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_CACHE_OPTION_CA", + ("hipJitCacheModeOptionCA", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_PREFER_PTX", + ("hipJitFallbackPreferPtx", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_PREFER_BINARY", + ("hipJitFallbackPreferBinary", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_JIT_MAX_REGISTERS", ("hipJitOptionMaxRegisters", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_THREADS_PER_BLOCK", + ("hipJitOptionThreadsPerBlock", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_WALL_TIME", ("hipJitOptionWallTime", CONV_JIT, API_DRIVER)), + ("CU_JIT_INFO_LOG_BUFFER", ("hipJitOptionInfoLogBuffer", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES", + ("hipJitOptionInfoLogBufferSizeBytes", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_ERROR_LOG_BUFFER", + ("hipJitOptionErrorLogBuffer", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES", + ("hipJitOptionErrorLogBufferSizeBytes", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_OPTIMIZATION_LEVEL", + ("hipJitOptionOptimizationLevel", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_TARGET_FROM_CUCONTEXT", + ("hipJitOptionTargetFromContext", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_TARGET", ("hipJitOptionTarget", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_FALLBACK_STRATEGY", + ("hipJitOptionFallbackStrategy", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_GENERATE_DEBUG_INFO", + ("hipJitOptionGenerateDebugInfo", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_LOG_VERBOSE", ("hipJitOptionLogVerbose", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_GENERATE_LINE_INFO", + ("hipJitOptionGenerateLineInfo", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_CACHE_MODE", ("hipJitOptionCacheMode", CONV_JIT, API_DRIVER)), + ("CU_JIT_NEW_SM3X_OPT", ("hipJitOptionSm3xOpt", CONV_JIT, API_DRIVER)), + ("CU_JIT_FAST_COMPILE", ("hipJitOptionFastCompile", CONV_JIT, API_DRIVER)), + ("CU_JIT_NUM_OPTIONS", ("hipJitOptionNumOptions", CONV_JIT, API_DRIVER)), + ( + "CU_TARGET_COMPUTE_10", + ("hipJitTargetCompute10", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_11", + ("hipJitTargetCompute11", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_12", + ("hipJitTargetCompute12", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_13", + ("hipJitTargetCompute13", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_20", + ("hipJitTargetCompute20", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_21", + ("hipJitTargetCompute21", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_30", + ("hipJitTargetCompute30", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_32", + ("hipJitTargetCompute32", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_35", + ("hipJitTargetCompute35", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_37", + ("hipJitTargetCompute37", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_50", + ("hipJitTargetCompute50", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_52", + ("hipJitTargetCompute52", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_53", + ("hipJitTargetCompute53", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_60", + ("hipJitTargetCompute60", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_61", + ("hipJitTargetCompute61", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_62", + ("hipJitTargetCompute62", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_CUBIN", + ("hipJitInputTypeBin", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_PTX", + ("hipJitInputTypePtx", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_FATBINARY", + ("hipJitInputTypeFatBinary", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_OBJECT", + ("hipJitInputTypeObject", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_LIBRARY", + ("hipJitInputTypeLibrary", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_NUM_INPUT_TYPES", + ("hipJitInputTypeNumInputTypes", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_STACK_SIZE", + ("hipLimitStackSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_PRINTF_FIFO_SIZE", + ("hipLimitPrintfFifoSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_MALLOC_HEAP_SIZE", + ("hipLimitMallocHeapSize", CONV_TYPE, API_DRIVER), + ), + ( + "CU_LIMIT_DEV_RUNTIME_SYNC_DEPTH", + ("hipLimitDevRuntimeSyncDepth", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT", + ( + "hipLimitDevRuntimePendingLaunchCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_LIMIT_STACK_SIZE", + ("hipLimitStackSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ATTACH_GLOBAL", + ("hipMemAttachGlobal", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ATTACH_HOST", + ("hipMemAttachHost", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ATTACH_SINGLE", + ("hipMemAttachSingle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_HOST", + ("hipMemTypeHost", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_DEVICE", + ("hipMemTypeDevice", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_ARRAY", + ("hipMemTypeArray", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_UNIFIED", + ("hipMemTypeUnified", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_ARRAY", + ("hipResourceTypeArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_MIPMAPPED_ARRAY", + ("hipResourceTypeMipmappedArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_LINEAR", + ("hipResourceTypeLinear", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_PITCH2D", + ("hipResourceTypePitch2D", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_RES_VIEW_FORMAT_NONE", ("hipResViewFormatNone", CONV_TEX, API_DRIVER)), + ( + "CU_RES_VIEW_FORMAT_UINT_1X8", + ("hipResViewFormatUnsignedChar1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_2X8", + ("hipResViewFormatUnsignedChar2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_4X8", + ("hipResViewFormatUnsignedChar4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_1X8", + ("hipResViewFormatSignedChar1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_2X8", + ("hipResViewFormatSignedChar2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_4X8", + ("hipResViewFormatSignedChar4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_1X16", + ("hipResViewFormatUnsignedShort1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_2X16", + ("hipResViewFormatUnsignedShort2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_4X16", + ("hipResViewFormatUnsignedShort4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_1X16", + ("hipResViewFormatSignedShort1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_2X16", + ("hipResViewFormatSignedShort2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_4X16", + ("hipResViewFormatSignedShort4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_1X32", + ("hipResViewFormatUnsignedInt1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_2X32", + ("hipResViewFormatUnsignedInt2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_4X32", + ("hipResViewFormatUnsignedInt4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_1X32", + ("hipResViewFormatSignedInt1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_2X32", + ("hipResViewFormatSignedInt2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_4X32", + ("hipResViewFormatSignedInt4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_1X16", + ("hipResViewFormatHalf1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_2X16", + ("hipResViewFormatHalf2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_4X16", + ("hipResViewFormatHalf4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_1X32", + ("hipResViewFormatFloat1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_2X32", + ("hipResViewFormatFloat2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_4X32", + ("hipResViewFormatFloat4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC1", + ("hipResViewFormatUnsignedBlockCompressed1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC2", + ("hipResViewFormatUnsignedBlockCompressed2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC3", + ("hipResViewFormatUnsignedBlockCompressed3", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC4", + ("hipResViewFormatUnsignedBlockCompressed4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SIGNED_BC4", + ("hipResViewFormatSignedBlockCompressed4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC5", + ("hipResViewFormatUnsignedBlockCompressed5", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SIGNED_BC5", + ("hipResViewFormatSignedBlockCompressed5", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC6H", + ("hipResViewFormatUnsignedBlockCompressed6H", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SIGNED_BC6H", + ("hipResViewFormatSignedBlockCompressed6H", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC7", + ("hipResViewFormatUnsignedBlockCompressed7", CONV_TEX, API_DRIVER), + ), + ( + "CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE", + ("hipSharedMemBankSizeDefault", CONV_TYPE, API_DRIVER), + ), + ( + "CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE", + ("hipSharedMemBankSizeFourByte", CONV_TYPE, API_DRIVER), + ), + ( + "CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE", + ("hipSharedMemBankSizeEightByte", CONV_TYPE, API_DRIVER), + ), + ("CU_STREAM_DEFAULT", ("hipStreamDefault", CONV_TYPE, API_DRIVER)), + ("CU_STREAM_NON_BLOCKING", ("hipStreamNonBlocking", CONV_TYPE, API_DRIVER)), + ( + "CU_STREAM_WAIT_VALUE_GEQ", + ("hipStreamWaitValueGeq", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WAIT_VALUE_EQ", + ("hipStreamWaitValueEq", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WAIT_VALUE_AND", + ("hipStreamWaitValueAnd", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WAIT_VALUE_FLUSH", + ("hipStreamWaitValueFlush", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WRITE_VALUE_DEFAULT", + ("hipStreamWriteValueDefault", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WRITE_VALUE_NO_MEMORY_BARRIER", + ( + "hipStreamWriteValueNoMemoryBarrier", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_STREAM_MEM_OP_WAIT_VALUE_32", + ("hipStreamBatchMemOpWaitValue32", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_MEM_OP_WRITE_VALUE_32", + ("hipStreamBatchMemOpWriteValue32", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES", + ( + "hipStreamBatchMemOpFlushRemoteWrites", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGetErrorName", + ("hipGetErrorName", CONV_ERROR, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGetErrorString", + ("hipDrvGetErrorString", CONV_ERROR, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuInit", ("hipInit", CONV_INIT, API_DRIVER)), + ("cuDriverGetVersion", ("hipDriverGetVersion", CONV_VERSION, API_DRIVER)), + ("cuCtxCreate", ("hipCtxCreate", CONV_CONTEXT, API_DRIVER)), + ("cuCtxCreate_v2", ("hipCtxCreate", CONV_CONTEXT, API_DRIVER)), + ("cuCtxDestroy", ("hipCtxDestroy", CONV_CONTEXT, API_DRIVER)), + ("cuCtxDestroy_v2", ("hipCtxDestroy", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetApiVersion", ("hipCtxGetApiVersion", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetCacheConfig", ("hipCtxGetCacheConfig", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetCurrent", ("hipCtxGetCurrent", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetDevice", ("hipCtxGetDevice", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetFlags", ("hipCtxGetFlags", CONV_CONTEXT, API_DRIVER)), + ("cuDeviceGetUuid", ("hipDeviceGetUuid", CONV_CONTEXT, API_DRIVER)), + ( + "cuCtxGetLimit", + ("hipCtxGetLimit", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuCtxGetSharedMemConfig", + ("hipCtxGetSharedMemConfig", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuCtxGetStreamPriorityRange", + ("hipCtxGetStreamPriorityRange", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuCtxPopCurrent_v2", ("hipCtxPopCurrent", CONV_CONTEXT, API_DRIVER)), + ("cuCtxPushCurrent_v2", ("hipCtxPushCurrent", CONV_CONTEXT, API_DRIVER)), + ("cuCtxSetCacheConfig", ("hipCtxSetCacheConfig", CONV_CONTEXT, API_DRIVER)), + ("cuCtxSetCurrent", ("hipCtxSetCurrent", CONV_CONTEXT, API_DRIVER)), + ( + "cuCtxSetLimit", + ("hipCtxSetLimit", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuCtxSetSharedMemConfig", + ("hipCtxSetSharedMemConfig", CONV_CONTEXT, API_DRIVER), + ), + ("cuCtxSynchronize", ("hipCtxSynchronize", CONV_CONTEXT, API_DRIVER)), + ("cuCtxAttach", ("hipCtxAttach", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED)), + ("cuCtxDetach", ("hipCtxDetach", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED)), + ("cuCtxEnablePeerAccess", ("hipCtxEnablePeerAccess", CONV_PEER, API_DRIVER)), + ("cuCtxDisablePeerAccess", ("hipCtxDisablePeerAccess", CONV_PEER, API_DRIVER)), + ("cuDeviceCanAccessPeer", ("hipDeviceCanAccessPeer", CONV_PEER, API_DRIVER)), + ( + "cuDeviceGetP2PAttribute", + ("hipDeviceGetP2PAttribute", CONV_PEER, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuDevicePrimaryCtxGetState", + ("hipDevicePrimaryCtxGetState", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxRelease", + ("hipDevicePrimaryCtxRelease", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxReset", + ("hipDevicePrimaryCtxReset", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxRetain", + ("hipDevicePrimaryCtxRetain", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxSetFlags", + ("hipDevicePrimaryCtxSetFlags", CONV_CONTEXT, API_DRIVER), + ), + ("cuDeviceGet", ("hipDeviceGet", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetName", ("hipDeviceGetName", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetCount", ("hipGetDeviceCount", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetAttribute", ("hipDeviceGetAttribute", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetPCIBusId", ("hipDeviceGetPCIBusId", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetByPCIBusId", ("hipDeviceGetByPCIBusId", CONV_DEVICE, API_DRIVER)), + ("cuDeviceTotalMem_v2", ("hipDeviceTotalMem", CONV_DEVICE, API_DRIVER)), + ( + "cuDeviceComputeCapability", + ("hipDeviceComputeCapability", CONV_DEVICE, API_DRIVER), + ), + ("cuDeviceGetProperties", ("hipGetDeviceProperties", CONV_DEVICE, API_DRIVER)), + ("cuLinkAddData", ("hipLinkAddData", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuLinkAddFile", ("hipLinkAddFile", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuLinkComplete", + ("hipLinkComplete", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuLinkCreate", ("hipLinkCreate", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuLinkDestroy", ("hipLinkDestroy", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuModuleGetFunction", ("hipModuleGetFunction", CONV_MODULE, API_DRIVER)), + ("cuModuleGetGlobal_v2", ("hipModuleGetGlobal", CONV_MODULE, API_DRIVER)), + ( + "cuModuleGetSurfRef", + ("hipModuleGetSurfRef", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuModuleGetTexRef", ("hipModuleGetTexRef", CONV_MODULE, API_DRIVER)), + ("cuModuleLoad", ("hipModuleLoad", CONV_MODULE, API_DRIVER)), + ("cuModuleLoadData", ("hipModuleLoadData", CONV_MODULE, API_DRIVER)), + ("cuModuleLoadDataEx", ("hipModuleLoadDataEx", CONV_MODULE, API_DRIVER)), + ( + "cuModuleLoadFatBinary", + ("hipModuleLoadFatBinary", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuModuleUnload", ("hipModuleUnload", CONV_MODULE, API_DRIVER)), + ( + "CU_DEVICE_P2P_ATTRIBUTE_PERFORMANCE_RANK", + ( + "hipDeviceP2PAttributePerformanceRank", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_P2P_ATTRIBUTE_ACCESS_SUPPORTED", + ( + "hipDeviceP2PAttributeAccessSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_P2P_ATTRIBUTE_NATIVE_ATOMIC_SUPPORTED", + ( + "hipDeviceP2PAttributeNativeAtomicSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("CU_EVENT_DEFAULT", ("hipEventDefault", CONV_EVENT, API_DRIVER)), + ("CU_EVENT_BLOCKING_SYNC", ("hipEventBlockingSync", CONV_EVENT, API_DRIVER)), + ("CU_EVENT_DISABLE_TIMING", ("hipEventDisableTiming", CONV_EVENT, API_DRIVER)), + ("CU_EVENT_INTERPROCESS", ("hipEventInterprocess", CONV_EVENT, API_DRIVER)), + ("cuEventCreate", ("hipEventCreate", CONV_EVENT, API_DRIVER)), + ("cuEventDestroy", ("hipEventDestroy", CONV_EVENT, API_DRIVER)), + ("cuEventDestroy_v2", ("hipEventDestroy", CONV_EVENT, API_DRIVER)), + ("cuEventElapsedTime", ("hipEventElapsedTime", CONV_EVENT, API_DRIVER)), + ("cuEventQuery", ("hipEventQuery", CONV_EVENT, API_DRIVER)), + ("cuEventRecord", ("hipEventRecord", CONV_EVENT, API_DRIVER)), + ("cuEventSynchronize", ("hipEventSynchronize", CONV_EVENT, API_DRIVER)), + ("cuFuncSetAttribute", ("hipFuncSetAttribute", CONV_EVENT, API_DRIVER)), + ( + "cuFuncGetAttribute", + ("hipFuncGetAttribute", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuFuncSetCacheConfig", ("hipFuncSetCacheConfig", CONV_MODULE, API_DRIVER)), + ( + "cuFuncSetSharedMemConfig", + ("hipFuncSetSharedMemConfig", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuLaunchKernel", ("hipModuleLaunchKernel", CONV_MODULE, API_DRIVER)), + ( + "cuFuncSetBlockShape", + ("hipFuncSetBlockShape", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cudaLaunchKernel", ("hipLaunchKernel", CONV_MODULE, API_DRIVER)), + ( + "cuFuncSetSharedSize", + ("hipFuncSetSharedSize", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuLaunch", ("hipLaunch", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuLaunchGrid", ("hipLaunchGrid", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuLaunchGridAsync", + ("hipLaunchGridAsync", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuParamSetf", ("hipParamSetf", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuParamSeti", ("hipParamSeti", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuParamSetSize", + ("hipParamSetSize", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuParamSetSize", + ("hipParamSetSize", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuParamSetv", ("hipParamSetv", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuOccupancyMaxActiveBlocksPerMultiprocessor", + ( + "hipModuleOccupancyMaxActiveBlocksPerMultiprocessor", + CONV_OCCUPANCY, + API_DRIVER, + ), + ), + ( + "cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + ( + "hipModuleOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + CONV_OCCUPANCY, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuOccupancyMaxPotentialBlockSize", + ("hipModuleOccupancyMaxPotentialBlockSize", CONV_OCCUPANCY, API_DRIVER), + ), + ( + "cuOccupancyMaxPotentialBlockSizeWithFlags", + ( + "hipModuleOccupancyMaxPotentialBlockSizeWithFlags", + CONV_OCCUPANCY, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("cuStreamAddCallback", ("hipStreamAddCallback", CONV_STREAM, API_DRIVER)), + ( + "cuStreamAttachMemAsync", + ("hipStreamAttachMemAsync", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamCreate", + ("hipStreamCreate__", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamCreateWithPriority", + ("hipStreamCreateWithPriority", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuStreamDestroy", ("hipStreamDestroy", CONV_STREAM, API_DRIVER)), + ("cuStreamDestroy_v2", ("hipStreamDestroy", CONV_STREAM, API_DRIVER)), + ("cuStreamGetFlags", ("hipStreamGetFlags", CONV_STREAM, API_DRIVER)), + ( + "cuStreamGetPriority", + ("hipStreamGetPriority", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuStreamQuery", ("hipStreamQuery", CONV_STREAM, API_DRIVER)), + ("cuStreamSynchronize", ("hipStreamSynchronize", CONV_STREAM, API_DRIVER)), + ("cuStreamWaitEvent", ("hipStreamWaitEvent", CONV_STREAM, API_DRIVER)), + ( + "cuStreamWaitValue32", + ("hipStreamWaitValue32", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamWriteValue32", + ("hipStreamWriteValue32", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamBatchMemOp", + ("hipStreamBatchMemOp", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuArray3DCreate", ("hipArray3DCreate", CONV_MEM, API_DRIVER)), + ( + "cuArray3DGetDescriptor", + ("hipArray3DGetDescriptor", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuArrayCreate", ("hipArrayCreate", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuArrayDestroy", ("hipArrayDestroy", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuArrayGetDescriptor", + ("hipArrayGetDescriptor", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcCloseMemHandle", + ("hipIpcCloseMemHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcGetEventHandle", + ("hipIpcGetEventHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcGetMemHandle", + ("hipIpcGetMemHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcOpenEventHandle", + ("hipIpcOpenEventHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcOpenMemHandle", + ("hipIpcOpenMemHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemAlloc_v2", ("hipMalloc", CONV_MEM, API_DRIVER)), + ("cuMemAllocHost", ("hipMemAllocHost", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemAllocManaged", + ("hipMemAllocManaged", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemAllocPitch", + ("hipMemAllocPitch__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpy", ("hipMemcpy__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpy2D", ("hipMemcpy2D__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpy2DAsync", + ("hipMemcpy2DAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemcpy2DUnaligned", + ("hipMemcpy2DUnaligned", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpy3D", ("hipMemcpy3D__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpy3DAsync", + ("hipMemcpy3DAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemcpy3DPeer", + ("hipMemcpy3DPeer__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemcpy3DPeerAsync", + ("hipMemcpy3DPeerAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyAsync", ("hipMemcpyAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyAtoA", ("hipMemcpyAtoA", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyAtoD", ("hipMemcpyAtoD", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyAtoH", ("hipMemcpyAtoH", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpyAtoHAsync", + ("hipMemcpyAtoHAsync", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyDtoA", ("hipMemcpyDtoA", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyDtoD_v2", ("hipMemcpyDtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoDAsync_v2", ("hipMemcpyDtoDAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoH_v2", ("hipMemcpyDtoH", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoHAsync_v2", ("hipMemcpyDtoHAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoA", ("hipMemcpyHtoA", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpyHtoAAsync", + ("hipMemcpyHtoAAsync", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyHtoD_v2", ("hipMemcpyHtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoDAsync_v2", ("hipMemcpyHtoDAsync", CONV_MEM, API_DRIVER)), + ( + "cuMemcpyPeerAsync", + ("hipMemcpyPeerAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyPeer", ("hipMemcpyPeer__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemFree", ("hipFree", CONV_MEM, API_DRIVER)), + ("cuMemFree_v2", ("hipFree", CONV_MEM, API_DRIVER)), + ("cuMemFreeHost", ("hipHostFree", CONV_MEM, API_DRIVER)), + ( + "cuMemGetAddressRange", + ("hipMemGetAddressRange", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemGetInfo_v2", ("hipMemGetInfo", CONV_MEM, API_DRIVER)), + ("cuMemHostAlloc", ("hipHostMalloc", CONV_MEM, API_DRIVER)), + ( + "cuMemHostGetDevicePointer", + ("hipMemHostGetDevicePointer", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemHostGetFlags", + ("hipMemHostGetFlags", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemHostRegister_v2", ("hipHostRegister", CONV_MEM, API_DRIVER)), + ("cuMemHostUnregister", ("hipHostUnregister", CONV_MEM, API_DRIVER)), + ("cuMemsetD16_v2", ("hipMemsetD16", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD16Async", + ("hipMemsetD16Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD2D16_v2", ("hipMemsetD2D16", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD2D16Async", + ("hipMemsetD2D16Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD2D32_v2", ("hipMemsetD2D32", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD2D32Async", + ("hipMemsetD2D32Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD2D8_v2", ("hipMemsetD2D8", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD2D8Async", + ("hipMemsetD2D8Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD32_v2", ("hipMemset", CONV_MEM, API_DRIVER)), + ("cuMemsetD32Async", ("hipMemsetAsync", CONV_MEM, API_DRIVER)), + ("cuMemsetD8_v2", ("hipMemsetD8", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD8Async", + ("hipMemsetD8Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMipmappedArrayCreate", + ("hipMipmappedArrayCreate", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMipmappedArrayDestroy", + ("hipMipmappedArrayDestroy", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMipmappedArrayGetLevel", + ("hipMipmappedArrayGetLevel", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemPrefetchAsync", + ("hipMemPrefetchAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemAdvise", ("hipMemAdvise", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemRangeGetAttribute", + ("hipMemRangeGetAttribute", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemRangeGetAttributes", + ("hipMemRangeGetAttributes", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuPointerGetAttribute", + ("hipPointerGetAttribute", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemGetAddressRange_v2", + ("hipMemGetAddressRange", CONV_MEM, API_DRIVER), + ), + ( + "cuPointerGetAttributes", + ("hipPointerGetAttributes", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuPointerSetAttribute", + ("hipPointerSetAttribute", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_TR_FILTER_MODE_POINT", ("hipFilterModePoint", CONV_TEX, API_DRIVER)), + ( + "CU_TR_FILTER_MODE_LINEAR", + ("hipFilterModeLinear", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetAddress", + ("hipTexRefGetAddress", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetAddressMode", + ("hipTexRefGetAddressMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetArray", + ("hipTexRefGetArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetBorderColor", + ("hipTexRefGetBorderColor", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetFilterMode", + ("hipTexRefGetFilterMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetFlags", + ("hipTexRefGetFlags", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetFormat", + ("hipTexRefGetFormat", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMaxAnisotropy", + ("hipTexRefGetMaxAnisotropy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmapFilterMode", + ("hipTexRefGetMipmapFilterMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmapLevelBias", + ("hipTexRefGetMipmapLevelBias", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmapLevelClamp", + ("hipTexRefGetMipmapLevelClamp", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmappedArray", + ("hipTexRefGetMipmappedArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetAddress", + ("hipTexRefSetAddress", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetAddress2D", + ("hipTexRefSetAddress2D", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuTexRefSetAddressMode", ("hipTexRefSetAddressMode", CONV_TEX, API_DRIVER)), + ("cuTexRefSetArray", ("hipTexRefSetArray", CONV_TEX, API_DRIVER)), + ( + "cuTexRefSetBorderColor", + ("hipTexRefSetBorderColor", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuTexRefSetFilterMode", ("hipTexRefSetFilterMode", CONV_TEX, API_DRIVER)), + ("cuTexRefSetFlags", ("hipTexRefSetFlags", CONV_TEX, API_DRIVER)), + ("cuTexRefSetFormat", ("hipTexRefSetFormat", CONV_TEX, API_DRIVER)), + ( + "cuTexRefSetMaxAnisotropy", + ("hipTexRefSetMaxAnisotropy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmapFilterMode", + ("hipTexRefSetMipmapFilterMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmapLevelBias", + ("hipTexRefSetMipmapLevelBias", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmapLevelClamp", + ("hipTexRefSetMipmapLevelClamp", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmappedArray", + ("hipTexRefSetMipmappedArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuTexRefCreate", ("hipTexRefCreate", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuTexRefDestroy", + ("hipTexRefDestroy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfRefGetArray", + ("hipSurfRefGetArray", CONV_SURFACE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfRefSetArray", + ("hipSurfRefSetArray", CONV_SURFACE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectCreate", + ("hipTexObjectCreate", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectDestroy", + ("hipTexObjectDestroy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectGetResourceDesc", + ("hipTexObjectGetResourceDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectGetResourceViewDesc", + ("hipTexObjectGetResourceViewDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectGetTextureDesc", + ("hipTexObjectGetTextureDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfObjectCreate", + ("hipSurfObjectCreate", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfObjectDestroy", + ("hipSurfObjectDestroy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfObjectGetResourceDesc", + ("hipSurfObjectGetResourceDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsMapResources", + ("hipGraphicsMapResources", CONV_GRAPHICS, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsResourceGetMappedMipmappedArray", + ( + "hipGraphicsResourceGetMappedMipmappedArray", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsResourceGetMappedPointer", + ( + "hipGraphicsResourceGetMappedPointer", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsResourceSetMapFlags", + ( + "hipGraphicsResourceSetMapFlags", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsSubResourceGetMappedArray", + ( + "hipGraphicsSubResourceGetMappedArray", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsUnmapResources", + ("hipGraphicsUnmapResources", CONV_GRAPHICS, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsUnregisterResource", + ( + "hipGraphicsUnregisterResource", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuProfilerInitialize", + ("hipProfilerInitialize", CONV_OTHER, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuProfilerStart", ("hipProfilerStart", CONV_OTHER, API_DRIVER)), + ("cuProfilerStop", ("hipProfilerStop", CONV_OTHER, API_DRIVER)), + ( + "CU_GL_DEVICE_LIST_ALL", + ("HIP_GL_DEVICE_LIST_ALL", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GL_DEVICE_LIST_CURRENT_FRAME", + ("HIP_GL_DEVICE_LIST_CURRENT_FRAME", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GL_DEVICE_LIST_NEXT_FRAME", + ("HIP_GL_DEVICE_LIST_NEXT_FRAME", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuGLGetDevices", ("hipGLGetDevices", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuGraphicsGLRegisterBuffer", + ("hipGraphicsGLRegisterBuffer", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsGLRegisterImage", + ("hipGraphicsGLRegisterImage", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuWGLGetDevice", ("hipWGLGetDevice", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CU_GL_MAP_RESOURCE_FLAGS_NONE", + ("HIP_GL_MAP_RESOURCE_FLAGS_NONE", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GL_MAP_RESOURCE_FLAGS_READ_ONLY", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_READ_ONLY", + CONV_GL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + CONV_GL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("cuGLCtxCreate", ("hipGLCtxCreate", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ("cuGLInit", ("hipGLInit", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuGLMapBufferObject", + ("hipGLMapBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLMapBufferObjectAsync", + ("hipGLMapBufferObjectAsync", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLRegisterBufferObject", + ("hipGLRegisterBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLSetBufferObjectMapFlags", + ("hipGLSetBufferObjectMapFlags", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLUnmapBufferObject", + ("hipGLUnmapBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLUnmapBufferObjectAsync", + ("hipGLUnmapBufferObjectAsync", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLUnregisterBufferObject", + ("hipGLUnregisterBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_DEVICE_LIST_ALL", + ("HIP_D3D9_DEVICE_LIST_ALL", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_DEVICE_LIST_CURRENT_FRAME", + ( + "HIP_D3D9_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D9_DEVICE_LIST_NEXT_FRAME", + ("HIP_D3D9_DEVICE_LIST_NEXT_FRAME", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9CtxCreate", + ("hipD3D9CtxCreate", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9CtxCreateOnDevice", + ("hipD3D9CtxCreateOnDevice", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9GetDevice", + ("hipD3D9GetDevice", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9GetDevices", + ("hipD3D9GetDevices", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9GetDirect3DDevice", + ("hipD3D9GetDirect3DDevice", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsD3D9RegisterResource", + ("hipGraphicsD3D9RegisterResource", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_MAPRESOURCE_FLAGS_NONE", + ("HIP_D3D9_MAPRESOURCE_FLAGS_NONE", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_MAPRESOURCE_FLAGS_READONLY", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D9_MAPRESOURCE_FLAGS_WRITEDISCARD", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D9_REGISTER_FLAGS_NONE", + ("HIP_D3D9_REGISTER_FLAGS_NONE", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_REGISTER_FLAGS_ARRAY", + ("HIP_D3D9_REGISTER_FLAGS_ARRAY", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9MapResources", + ("hipD3D9MapResources", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9RegisterResource", + ("hipD3D9RegisterResource", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedArray", + ("hipD3D9ResourceGetMappedArray", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedPitch", + ("hipD3D9ResourceGetMappedPitch", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedPointer", + ("hipD3D9ResourceGetMappedPointer", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedSize", + ("hipD3D9ResourceGetMappedSize", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetSurfaceDimensions", + ( + "hipD3D9ResourceGetSurfaceDimensions", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D9ResourceSetMapFlags", + ("hipD3D9ResourceSetMapFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9UnmapResources", + ("hipD3D9UnmapResources", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9UnregisterResource", + ("hipD3D9UnregisterResource", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D10_DEVICE_LIST_ALL", + ("HIP_D3D10_DEVICE_LIST_ALL", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D10_DEVICE_LIST_CURRENT_FRAME", + ( + "HIP_D3D10_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_DEVICE_LIST_NEXT_FRAME", + ( + "HIP_D3D10_DEVICE_LIST_NEXT_FRAME", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D10GetDevice", + ("hipD3D10GetDevice", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10GetDevices", + ("hipD3D10GetDevices", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsD3D10RegisterResource", + ( + "hipGraphicsD3D10RegisterResource", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_MAPRESOURCE_FLAGS_NONE", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_NONE", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_MAPRESOURCE_FLAGS_READONLY", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_MAPRESOURCE_FLAGS_WRITEDISCARD", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_REGISTER_FLAGS_NONE", + ("HIP_D3D10_REGISTER_FLAGS_NONE", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D10_REGISTER_FLAGS_ARRAY", + ("HIP_D3D10_REGISTER_FLAGS_ARRAY", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10CtxCreate", + ("hipD3D10CtxCreate", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10CtxCreateOnDevice", + ("hipD3D10CtxCreateOnDevice", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10GetDirect3DDevice", + ("hipD3D10GetDirect3DDevice", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10MapResources", + ("hipD3D10MapResources", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10RegisterResource", + ("hipD3D10RegisterResource", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetMappedArray", + ("hipD3D10ResourceGetMappedArray", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetMappedPitch", + ("hipD3D10ResourceGetMappedPitch", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetMappedPointer", + ( + "hipD3D10ResourceGetMappedPointer", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D10ResourceGetMappedSize", + ("hipD3D10ResourceGetMappedSize", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetSurfaceDimensions", + ( + "hipD3D10ResourceGetSurfaceDimensions", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD310ResourceSetMapFlags", + ("hipD3D10ResourceSetMapFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10UnmapResources", + ("hipD3D10UnmapResources", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10UnregisterResource", + ("hipD3D10UnregisterResource", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D11_DEVICE_LIST_ALL", + ("HIP_D3D11_DEVICE_LIST_ALL", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D11_DEVICE_LIST_CURRENT_FRAME", + ( + "HIP_D3D11_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D11, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D11_DEVICE_LIST_NEXT_FRAME", + ( + "HIP_D3D11_DEVICE_LIST_NEXT_FRAME", + CONV_D3D11, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D11GetDevice", + ("hipD3D11GetDevice", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D11GetDevices", + ("hipD3D11GetDevices", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsD3D11RegisterResource", + ( + "hipGraphicsD3D11RegisterResource", + CONV_D3D11, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D11CtxCreate", + ("hipD3D11CtxCreate", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D11CtxCreateOnDevice", + ("hipD3D11CtxCreateOnDevice", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D11GetDirect3DDevice", + ("hipD3D11GetDirect3DDevice", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsVDPAURegisterOutputSurface", + ( + "hipGraphicsVDPAURegisterOutputSurface", + CONV_VDPAU, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsVDPAURegisterVideoSurface", + ( + "hipGraphicsVDPAURegisterVideoSurface", + CONV_VDPAU, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuVDPAUGetDevice", + ("hipVDPAUGetDevice", CONV_VDPAU, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuVDPAUCtxCreate", + ("hipVDPAUCtxCreate", CONV_VDPAU, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerAcquireFrame", + ("hipEGLStreamConsumerAcquireFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerConnect", + ("hipEGLStreamConsumerConnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerConnectWithFlags", + ( + "hipEGLStreamConsumerConnectWithFlags", + CONV_EGL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuEGLStreamConsumerDisconnect", + ("hipEGLStreamConsumerDisconnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerReleaseFrame", + ("hipEGLStreamConsumerReleaseFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerConnect", + ("hipEGLStreamProducerConnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerDisconnect", + ("hipEGLStreamProducerDisconnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerPresentFrame", + ("hipEGLStreamProducerPresentFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerReturnFrame", + ("hipEGLStreamProducerReturnFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsEGLRegisterImage", + ("hipGraphicsEGLRegisterImage", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsResourceGetMappedEglFrame", + ( + "hipGraphicsResourceGetMappedEglFrame", + CONV_EGL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("cudaDataType_t", ("hipDataType", CONV_TYPE, API_RUNTIME)), + ("cudaDataType", ("hipDataType", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_32F", ("HIP_R_32F", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_64F", ("HIP_R_64F", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_16F", ("HIP_R_16F", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_8I", ("HIP_R_8I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_32F", ("HIP_C_32F", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_64F", ("HIP_C_64F", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_16F", ("HIP_C_16F", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_8I", ("HIP_C_8I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_8U", ("HIP_R_8U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_8U", ("HIP_C_8U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_32I", ("HIP_R_32I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_32I", ("HIP_C_32I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_32U", ("HIP_R_32U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_32U", ("HIP_C_32U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_16BF", ("HIP_R_16BF", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_16BF", ("HIP_C_16BF", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_4I", ("HIP_R_4I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_4I", ("HIP_C_4I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_4U", ("HIP_R_4U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_4U", ("HIP_C_4U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_16I", ("HIP_R_16I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_16I", ("HIP_C_16I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_16U", ("HIP_R_16U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_16U", ("HIP_C_16U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_64I", ("HIP_R_64I", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_64I", ("HIP_C_64I", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_64U", ("HIP_R_64U", CONV_TYPE, API_RUNTIME)), + ("CUDA_C_64U", ("HIP_C_64U", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_8F_E4M3", ("HIP_R_8F_E4M3", CONV_TYPE, API_RUNTIME)), + ("CUDA_R_8F_E5M2", ("HIP_R_8F_E5M2", CONV_TYPE, API_RUNTIME)), + ( + "MAJOR_VERSION", + ("hipLibraryMajorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "MINOR_VERSION", + ("hipLibraryMinorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "PATCH_LEVEL", + ("hipLibraryPatchVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAttachGlobal", + ("hipMemAttachGlobal", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAttachHost", + ("hipMemAttachHost", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAttachSingle", + ("hipMemAttachSingle", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaOccupancyDefault", + ("hipOccupancyDefault", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaOccupancyDisableCachingOverride", + ( + "hipOccupancyDisableCachingOverride", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaGetLastError", ("hipGetLastError", CONV_ERROR, API_RUNTIME)), + ("cudaPeekAtLastError", ("hipPeekAtLastError", CONV_ERROR, API_RUNTIME)), + ("cudaGetErrorName", ("hipGetErrorName", CONV_ERROR, API_RUNTIME)), + ("cudaGetErrorString", ("hipGetErrorString", CONV_ERROR, API_RUNTIME)), + ("cudaMemcpy3DParms", ("hipMemcpy3DParms", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpy3DPeerParms", + ("hipMemcpy3DPeerParms", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpy", ("hipMemcpy", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyToArray", ("hipMemcpyToArray", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyToSymbol", ("hipMemcpyToSymbol", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyToSymbolAsync", ("hipMemcpyToSymbolAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyAsync", ("hipMemcpyAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpy2D", ("hipMemcpy2D", CONV_MEM, API_RUNTIME)), + ("cudaMemcpy2DAsync", ("hipMemcpy2DAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpy2DToArray", ("hipMemcpy2DToArray", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpy2DArrayToArray", + ("hipMemcpy2DArrayToArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy2DFromArray", + ("hipMemcpy2DFromArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy2DFromArrayAsync", + ("hipMemcpy2DFromArrayAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy2DToArrayAsync", + ("hipMemcpy2DToArrayAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpy3D", ("hipMemcpy3D", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpy3DAsync", + ("hipMemcpy3DAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy3DPeer", + ("hipMemcpy3DPeer", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy3DPeerAsync", + ("hipMemcpy3DPeerAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpyArrayToArray", + ("hipMemcpyArrayToArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpyFromArrayAsync", + ("hipMemcpyFromArrayAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpyFromSymbol", ("hipMemcpyFromSymbol", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpyFromSymbolAsync", + ("hipMemcpyFromSymbolAsync", CONV_MEM, API_RUNTIME), + ), + ("cudaMemAdvise", ("hipMemAdvise", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaMemRangeGetAttribute", + ("hipMemRangeGetAttribute", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeGetAttributes", + ("hipMemRangeGetAttributes", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseSetReadMostly", + ("hipMemAdviseSetReadMostly", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseUnsetReadMostly", + ("hipMemAdviseUnsetReadMostly", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseSetPreferredLocation", + ( + "hipMemAdviseSetPreferredLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaMemAdviseUnsetPreferredLocation", + ( + "hipMemAdviseUnsetPreferredLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaMemAdviseSetAccessedBy", + ("hipMemAdviseSetAccessedBy", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseUnsetAccessedBy", + ("hipMemAdviseUnsetAccessedBy", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeAttributeReadMostly", + ("hipMemRangeAttributeReadMostly", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeAttributePreferredLocation", + ( + "hipMemRangeAttributePreferredLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaMemRangeAttributeAccessedBy", + ("hipMemRangeAttributeAccessedBy", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeAttributeLastPrefetchLocation", + ( + "hipMemRangeAttributeLastPrefetchLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaMemcpyHostToHost", ("hipMemcpyHostToHost", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyHostToDevice", ("hipMemcpyHostToDevice", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyDeviceToHost", ("hipMemcpyDeviceToHost", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpyDeviceToDevice", + ("hipMemcpyDeviceToDevice", CONV_MEM, API_RUNTIME), + ), + ("cudaMemcpyDefault", ("hipMemcpyDefault", CONV_MEM, API_RUNTIME)), + ("cudaMemset", ("hipMemset", CONV_MEM, API_RUNTIME)), + ("cudaMemsetAsync", ("hipMemsetAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemset2D", ("hipMemset2D", CONV_MEM, API_RUNTIME)), + ( + "cudaMemset2DAsync", + ("hipMemset2DAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemset3D", ("hipMemset3D", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaMemset3DAsync", + ("hipMemset3DAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemGetInfo", ("hipMemGetInfo", CONV_MEM, API_RUNTIME)), + ("cudaDeviceGetDefaultMemPool", ("hipDeviceGetDefaultMemPool", CONV_MEM, API_RUNTIME)), + ("cudaMemAccessDesc", ("hipMemAccessDesc", CONV_MEM, API_RUNTIME)), + ("cudaMemAccessFlagsProtReadWrite", ("hipMemAccessFlagsProtReadWrite", CONV_MEM, API_RUNTIME)), + ("cudaMemLocationTypeDevice", ("hipMemLocationTypeDevice", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrReleaseThreshold", ("hipMemPoolAttrReleaseThreshold", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrReservedMemCurrent", ("hipMemPoolAttrReservedMemCurrent", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrReservedMemHigh", ("hipMemPoolAttrReservedMemHigh", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrUsedMemCurrent", ("hipMemPoolAttrUsedMemCurrent", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolAttrUsedMemHigh", ("hipMemPoolAttrUsedMemHigh", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolGetAttribute", ("hipMemPoolGetAttribute", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolReuseAllowInternalDependencies", ("hipMemPoolReuseAllowInternalDependencies", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolReuseAllowOpportunistic", ("hipMemPoolReuseAllowOpportunistic", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolReuseFollowEventDependencies", ("hipMemPoolReuseFollowEventDependencies", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolSetAccess", ("hipMemPoolSetAccess", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolSetAttribute", ("hipMemPoolSetAttribute", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolTrimTo", ("hipMemPoolTrimTo", CONV_MEM, API_RUNTIME)), + ("cudaMemPool_t", ("hipMemPool_t", CONV_MEM, API_RUNTIME)), + ( + "cudaArrayGetInfo", + ("hipArrayGetInfo", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFreeMipmappedArray", + ("hipFreeMipmappedArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetMipmappedArrayLevel", + ("hipGetMipmappedArrayLevel", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSymbolAddress", + ("hipGetSymbolAddress", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSymbolSize", + ("hipGetSymbolSize", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemPrefetchAsync", + ("hipMemPrefetchAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMallocHost", ("hipHostMalloc", CONV_MEM, API_RUNTIME)), + ("cudaMallocArray", ("hipMallocArray", CONV_MEM, API_RUNTIME)), + ("cudaMalloc", ("hipMalloc", CONV_MEM, API_RUNTIME)), + ("cudaMalloc3D", ("hipMalloc3D", CONV_MEM, API_RUNTIME)), + ("cudaMalloc3DArray", ("hipMalloc3DArray", CONV_MEM, API_RUNTIME)), + ( + "cudaMallocManaged", + ("hipMallocManaged", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMallocMipmappedArray", + ("hipMallocMipmappedArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMallocPitch", ("hipMallocPitch", CONV_MEM, API_RUNTIME)), + ("cudaFreeHost", ("hipHostFree", CONV_MEM, API_RUNTIME)), + ("cudaFreeArray", ("hipFreeArray", CONV_MEM, API_RUNTIME)), + ("cudaFree", ("hipFree", CONV_MEM, API_RUNTIME)), + ("cudaHostRegister", ("hipHostRegister", CONV_MEM, API_RUNTIME)), + ("cudaHostUnregister", ("hipHostUnregister", CONV_MEM, API_RUNTIME)), + ("cudaHostAlloc", ("hipHostMalloc", CONV_MEM, API_RUNTIME)), + ("cudaMemoryTypeHost", ("hipMemoryTypeHost", CONV_MEM, API_RUNTIME)), + ("cudaMemoryTypeDevice", ("hipMemoryTypeDevice", CONV_MEM, API_RUNTIME)), + ("make_cudaExtent", ("make_hipExtent", CONV_MEM, API_RUNTIME)), + ("make_cudaPitchedPtr", ("make_hipPitchedPtr", CONV_MEM, API_RUNTIME)), + ("make_cudaPos", ("make_hipPos", CONV_MEM, API_RUNTIME)), + ("cudaHostAllocDefault", ("hipHostMallocDefault", CONV_MEM, API_RUNTIME)), + ("cudaHostAllocPortable", ("hipHostMallocPortable", CONV_MEM, API_RUNTIME)), + ("cudaHostAllocMapped", ("hipHostMallocMapped", CONV_MEM, API_RUNTIME)), + ( + "cudaHostAllocWriteCombined", + ("hipHostMallocWriteCombined", CONV_MEM, API_RUNTIME), + ), + ("cudaHostGetFlags", ("hipHostGetFlags", CONV_MEM, API_RUNTIME)), + ("cudaHostRegisterDefault", ("hipHostRegisterDefault", CONV_MEM, API_RUNTIME)), + ( + "cudaHostRegisterPortable", + ("hipHostRegisterPortable", CONV_MEM, API_RUNTIME), + ), + ("cudaHostRegisterMapped", ("hipHostRegisterMapped", CONV_MEM, API_RUNTIME)), + ( + "cudaHostRegisterIoMemory", + ("hipHostRegisterIoMemory", CONV_MEM, API_RUNTIME), + ), + # ("warpSize", ("hipWarpSize", CONV_SPECIAL_FUNC, API_RUNTIME), (HIP actually uses warpSize...)), + ("cudaEventCreate", ("hipEventCreate", CONV_EVENT, API_RUNTIME)), + ( + "cudaEventCreateWithFlags", + ("hipEventCreateWithFlags", CONV_EVENT, API_RUNTIME), + ), + ("cudaEventDestroy", ("hipEventDestroy", CONV_EVENT, API_RUNTIME)), + ("cudaEventRecord", ("hipEventRecord", CONV_EVENT, API_RUNTIME)), + ("cudaEventElapsedTime", ("hipEventElapsedTime", CONV_EVENT, API_RUNTIME)), + ("cudaEventSynchronize", ("hipEventSynchronize", CONV_EVENT, API_RUNTIME)), + ("cudaEventQuery", ("hipEventQuery", CONV_EVENT, API_RUNTIME)), + ("cudaEventDefault", ("hipEventDefault", CONV_EVENT, API_RUNTIME)), + ("cudaEventBlockingSync", ("hipEventBlockingSync", CONV_EVENT, API_RUNTIME)), + ("cudaEventDisableTiming", ("hipEventDisableTiming", CONV_EVENT, API_RUNTIME)), + ("cudaEventInterprocess", ("hipEventInterprocess", CONV_EVENT, API_RUNTIME)), + ("cudaStreamCreate", ("hipStreamCreate", CONV_STREAM, API_RUNTIME)), + ( + "cudaStreamCreateWithFlags", + ("hipStreamCreateWithFlags", CONV_STREAM, API_RUNTIME), + ), + ( + "cudaStreamCreateWithPriority", + ("hipStreamCreateWithPriority", CONV_STREAM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaStreamDestroy", ("hipStreamDestroy", CONV_STREAM, API_RUNTIME)), + ("cudaStreamWaitEvent", ("hipStreamWaitEvent", CONV_STREAM, API_RUNTIME)), + ("cudaStreamSynchronize", ("hipStreamSynchronize", CONV_STREAM, API_RUNTIME)), + ("cudaStreamGetFlags", ("hipStreamGetFlags", CONV_STREAM, API_RUNTIME)), + ("cudaStreamQuery", ("hipStreamQuery", CONV_STREAM, API_RUNTIME)), + ("cudaStreamAddCallback", ("hipStreamAddCallback", CONV_STREAM, API_RUNTIME)), + ( + "cudaStreamAttachMemAsync", + ("hipStreamAttachMemAsync", CONV_STREAM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaStreamGetPriority", + ("hipStreamGetPriority", CONV_STREAM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaCpuDeviceId", ("hipCpuDeviceId", CONV_TYPE, API_RUNTIME)), + ("cudaStreamDefault", ("hipStreamDefault", CONV_TYPE, API_RUNTIME)), + ("cudaStreamNonBlocking", ("hipStreamNonBlocking", CONV_TYPE, API_RUNTIME)), + ("cudaStreamGetCaptureInfo", ("hipStreamGetCaptureInfo", CONV_TYPE, API_RUNTIME)), + ("cudaStreamGetCaptureInfo_v2", ("hipStreamGetCaptureInfo_v2", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureStatus", ("hipStreamCaptureStatus", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureStatusActive", ("hipStreamCaptureStatusActive", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureMode", ("hipStreamCaptureMode", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureModeGlobal", ("hipStreamCaptureModeGlobal", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureModeRelaxed", ("hipStreamCaptureModeRelaxed", CONV_TYPE, API_RUNTIME)), + ("cudaStreamCaptureModeThreadLocal", ("hipStreamCaptureModeThreadLocal", CONV_TYPE, API_RUNTIME)), + ("cudaStreamBeginCapture", ("hipStreamBeginCapture", CONV_TYPE, API_RUNTIME)), + ("cudaStreamEndCapture", ("hipStreamEndCapture", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiate", ("hipGraphInstantiate", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateWithFlags", ("hipGraphInstantiateWithFlags", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateFlagAutoFreeOnLaunch", ("hipGraphInstantiateFlagAutoFreeOnLaunch", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDestroy", ("hipGraphDestroy", CONV_TYPE, API_RUNTIME)), + ("cudaGraphExecDestroy", ("hipGraphExecDestroy", CONV_TYPE, API_RUNTIME)), + ("cudaGraphLaunch", ("hipGraphLaunch", CONV_TYPE, API_RUNTIME)), + ("cudaGraphGetNodes", ("hipGraphGetNodes", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotPrint", ("hipGraphDebugDotPrint", CONV_TYPE, API_RUNTIME)), + ("cudaGraphDebugDotFlagsVerbose", ("hipGraphDebugDotFlagsVerbose", CONV_NUMERIC_LITERAL, API_RUNTIME)), + ("cudaGraphRetainUserObject", ("hipGraphRetainUserObject", CONV_TYPE, API_RUNTIME)), + ("cudaGraphUserObjectMove", ("hipGraphUserObjectMove", CONV_TYPE, API_RUNTIME)), + ("cudaUserObject_t", ("hipUserObject_t", CONV_TYPE, API_RUNTIME)), + ("cudaUserObjectCreate", ("hipUserObjectCreate", CONV_TYPE, API_RUNTIME)), + ("cudaUserObjectNoDestructorSync", ("hipUserObjectNoDestructorSync", CONV_TYPE, API_RUNTIME)), + ("cudaThreadExchangeStreamCaptureMode", ("hipThreadExchangeStreamCaptureMode", CONV_TYPE, API_RUNTIME)), + ("cudaStreamIsCapturing", ("hipStreamIsCapturing", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceSynchronize", ("hipDeviceSynchronize", CONV_DEVICE, API_RUNTIME)), + ("cudaDeviceReset", ("hipDeviceReset", CONV_DEVICE, API_RUNTIME)), + ("cudaSetDevice", ("hipSetDevice", CONV_DEVICE, API_RUNTIME)), + ("cudaGetDevice", ("hipGetDevice", CONV_DEVICE, API_RUNTIME)), + ("cudaGetDeviceCount", ("hipGetDeviceCount", CONV_DEVICE, API_RUNTIME)), + ("cudaChooseDevice", ("hipChooseDevice", CONV_DEVICE, API_RUNTIME)), + ("cudaThreadExit", ("hipDeviceReset", CONV_THREAD, API_RUNTIME)), + ( + "cudaThreadGetCacheConfig", + ("hipDeviceGetCacheConfig", CONV_THREAD, API_RUNTIME), + ), + ( + "cudaThreadGetLimit", + ("hipThreadGetLimit", CONV_THREAD, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaThreadSetCacheConfig", + ("hipDeviceSetCacheConfig", CONV_THREAD, API_RUNTIME), + ), + ( + "cudaThreadSetLimit", + ("hipThreadSetLimit", CONV_THREAD, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaThreadSynchronize", ("hipDeviceSynchronize", CONV_THREAD, API_RUNTIME)), + ("cudaDeviceGetAttribute", ("hipDeviceGetAttribute", CONV_DEVICE, API_RUNTIME)), + ( + "cudaDevAttrMaxThreadsPerBlock", + ("hipDeviceAttributeMaxThreadsPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxBlockDimX", + ("hipDeviceAttributeMaxBlockDimX", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxBlockDimY", + ("hipDeviceAttributeMaxBlockDimY", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxBlockDimZ", + ("hipDeviceAttributeMaxBlockDimZ", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxGridDimX", + ("hipDeviceAttributeMaxGridDimX", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxGridDimY", + ("hipDeviceAttributeMaxGridDimY", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxGridDimZ", + ("hipDeviceAttributeMaxGridDimZ", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxSharedMemoryPerBlock", + ("hipDeviceAttributeMaxSharedMemoryPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxSharedMemoryPerBlockOptin", + ("hipDeviceAttributeMaxSharedMemoryPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrTotalConstantMemory", + ("hipDeviceAttributeTotalConstantMemory", CONV_TYPE, API_RUNTIME), + ), + ("cudaDevAttrWarpSize", ("hipDeviceAttributeWarpSize", CONV_TYPE, API_RUNTIME)), + ( + "cudaDevAttrMaxPitch", + ("hipDeviceAttributeMaxPitch", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrMaxRegistersPerBlock", + ("hipDeviceAttributeMaxRegistersPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrClockRate", + ("hipDeviceAttributeClockRate", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrTextureAlignment", + ( + "hipDeviceAttributeTextureAlignment", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrGpuOverlap", + ("hipDeviceAttributeGpuOverlap", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrMultiProcessorCount", + ("hipDeviceAttributeMultiprocessorCount", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrKernelExecTimeout", + ( + "hipDeviceAttributeKernelExecTimeout", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrIntegrated", + ("hipDeviceAttributeIntegrated", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrCanMapHostMemory", + ( + "hipDeviceAttributeCanMapHostMemory", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrComputeMode", + ("hipDeviceAttributeComputeMode", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxTexture1DWidth", + ( + "hipDeviceAttributeMaxTexture1DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DWidth", + ( + "hipDeviceAttributeMaxTexture2DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DHeight", + ( + "hipDeviceAttributeMaxTexture2DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DWidth", + ( + "hipDeviceAttributeMaxTexture3DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DHeight", + ( + "hipDeviceAttributeMaxTexture3DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DDepth", + ( + "hipDeviceAttributeMaxTexture3DDepth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLayeredWidth", + ( + "hipDeviceAttributeMaxTexture2DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLayeredHeight", + ( + "hipDeviceAttributeMaxTexture2DLayeredHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLayeredLayers", + ( + "hipDeviceAttributeMaxTexture2DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrSurfaceAlignment", + ( + "hipDeviceAttributeSurfaceAlignment", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrConcurrentKernels", + ("hipDeviceAttributeConcurrentKernels", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrEccEnabled", + ("hipDeviceAttributeEccEnabled", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDevAttrPciBusId", ("hipDeviceAttributePciBusId", CONV_TYPE, API_RUNTIME)), + ( + "cudaDevAttrPciDeviceId", + ("hipDeviceAttributePciDeviceId", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrTccDriver", + ("hipDeviceAttributeTccDriver", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrMemoryClockRate", + ("hipDeviceAttributeMemoryClockRate", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrGlobalMemoryBusWidth", + ("hipDeviceAttributeMemoryBusWidth", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrL2CacheSize", + ("hipDeviceAttributeL2CacheSize", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxThreadsPerMultiProcessor", + ("hipDeviceAttributeMaxThreadsPerMultiProcessor", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrAsyncEngineCount", + ( + "hipDeviceAttributeAsyncEngineCount", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrUnifiedAddressing", + ( + "hipDeviceAttributeUnifiedAddressing", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture1DLayeredWidth", + ( + "hipDeviceAttributeMaxTexture1DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture1DLayeredLayers", + ( + "hipDeviceAttributeMaxTexture1DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DGatherWidth", + ( + "hipDeviceAttributeMaxTexture2DGatherWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DGatherHeight", + ( + "hipDeviceAttributeMaxTexture2DGatherHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DWidthAlt", + ( + "hipDeviceAttributeMaxTexture3DWidthAlternate", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DHeightAlt", + ( + "hipDeviceAttributeMaxTexture3DHeightAlternate", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DDepthAlt", + ( + "hipDeviceAttributeMaxTexture3DDepthAlternate", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrPciDomainId", + ("hipDeviceAttributePciDomainId", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrTexturePitchAlignment", + ( + "hipDeviceAttributeTexturePitchAlignment", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTextureCubemapWidth", + ( + "hipDeviceAttributeMaxTextureCubemapWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTextureCubemapLayeredWidth", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTextureCubemapLayeredLayers", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface1DWidth", + ( + "hipDeviceAttributeMaxSurface1DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DWidth", + ( + "hipDeviceAttributeMaxSurface2DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DHeight", + ( + "hipDeviceAttributeMaxSurface2DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface3DWidth", + ( + "hipDeviceAttributeMaxSurface3DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface3DHeight", + ( + "hipDeviceAttributeMaxSurface3DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface3DDepth", + ( + "hipDeviceAttributeMaxSurface3DDepth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface1DLayeredWidth", + ( + "hipDeviceAttributeMaxSurface1DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface1DLayeredLayers", + ( + "hipDeviceAttributeMaxSurface1DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DLayeredWidth", + ( + "hipDeviceAttributeMaxSurface2DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DLayeredHeight", + ( + "hipDeviceAttributeMaxSurface2DLayeredHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DLayeredLayers", + ( + "hipDeviceAttributeMaxSurface2DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurfaceCubemapWidth", + ( + "hipDeviceAttributeMaxSurfaceCubemapWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurfaceCubemapLayeredWidth", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurfaceCubemapLayeredLayers", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture1DLinearWidth", + ( + "hipDeviceAttributeMaxTexture1DLinearWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLinearWidth", + ( + "hipDeviceAttributeMaxTexture2DLinearWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLinearHeight", + ( + "hipDeviceAttributeMaxTexture2DLinearHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLinearPitch", + ( + "hipDeviceAttributeMaxTexture2DLinearPitch", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DMipmappedWidth", + ( + "hipDeviceAttributeMaxTexture2DMipmappedWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DMipmappedHeight", + ( + "hipDeviceAttributeMaxTexture2DMipmappedHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrComputeCapabilityMajor", + ("hipDeviceAttributeComputeCapabilityMajor", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrComputeCapabilityMinor", + ("hipDeviceAttributeComputeCapabilityMinor", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxTexture1DMipmappedWidth", + ( + "hipDeviceAttributeMaxTexture1DMipmappedWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrStreamPrioritiesSupported", + ( + "hipDeviceAttributeStreamPrioritiesSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrGlobalL1CacheSupported", + ( + "hipDeviceAttributeGlobalL1CacheSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrLocalL1CacheSupported", + ( + "hipDeviceAttributeLocalL1CacheSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSharedMemoryPerMultiprocessor", + ( + "hipDeviceAttributeMaxSharedMemoryPerMultiprocessor", + CONV_TYPE, + API_RUNTIME, + ), + ), + ( + "cudaDevAttrMaxRegistersPerMultiprocessor", + ( + "hipDeviceAttributeMaxRegistersPerMultiprocessor", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrManagedMemory", + ( + "hipDeviceAttributeManagedMemory", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrIsMultiGpuBoard", + ("hipDeviceAttributeIsMultiGpuBoard", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMultiGpuBoardGroupID", + ( + "hipDeviceAttributeMultiGpuBoardGroupID", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrHostNativeAtomicSupported", + ( + "hipDeviceAttributeHostNativeAtomicSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrSingleToDoublePrecisionPerfRatio", + ( + "hipDeviceAttributeSingleToDoublePrecisionPerfRatio", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrPageableMemoryAccess", + ( + "hipDeviceAttributePageableMemoryAccess", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrConcurrentManagedAccess", + ( + "hipDeviceAttributeConcurrentManagedAccess", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrComputePreemptionSupported", + ( + "hipDeviceAttributeComputePreemptionSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrCanUseHostPointerForRegisteredMem", + ( + "hipDeviceAttributeCanUseHostPointerForRegisteredMem", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaPointerGetAttributes", + ("hipPointerGetAttributes", CONV_MEM, API_RUNTIME), + ), + ( + "cudaHostGetDevicePointer", + ("hipHostGetDevicePointer", CONV_MEM, API_RUNTIME), + ), + ( + "cudaGetDeviceProperties", + ("hipGetDeviceProperties", CONV_DEVICE, API_RUNTIME), + ), + ("cudaDeviceGetPCIBusId", ("hipDeviceGetPCIBusId", CONV_DEVICE, API_RUNTIME)), + ( + "cudaDeviceGetByPCIBusId", + ("hipDeviceGetByPCIBusId", CONV_DEVICE, API_RUNTIME), + ), + ( + "cudaDeviceGetStreamPriorityRange", + ( + "hipDeviceGetStreamPriorityRange", + CONV_DEVICE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaSetValidDevices", + ("hipSetValidDevices", CONV_DEVICE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevP2PAttrPerformanceRank", + ( + "hipDeviceP2PAttributePerformanceRank", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevP2PAttrAccessSupported", + ( + "hipDeviceP2PAttributeAccessSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevP2PAttrNativeAtomicSupported", + ( + "hipDeviceP2PAttributeNativeAtomicSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDeviceGetP2PAttribute", + ("hipDeviceGetP2PAttribute", CONV_DEVICE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeDefault", + ("hipComputeModeDefault", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeExclusive", + ("hipComputeModeExclusive", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeProhibited", + ("hipComputeModeProhibited", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeExclusiveProcess", + ("hipComputeModeExclusiveProcess", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetDeviceFlags", + ("hipGetDeviceFlags", CONV_DEVICE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaSetDeviceFlags", ("hipSetDeviceFlags", CONV_DEVICE, API_RUNTIME)), + ("cudaDeviceScheduleAuto", ("hipDeviceScheduleAuto", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceScheduleSpin", ("hipDeviceScheduleSpin", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceScheduleYield", ("hipDeviceScheduleYield", CONV_TYPE, API_RUNTIME)), + ( + "cudaDeviceBlockingSync", + ("hipDeviceScheduleBlockingSync", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDeviceScheduleBlockingSync", + ("hipDeviceScheduleBlockingSync", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDeviceScheduleMask", + ("hipDeviceScheduleMask", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDeviceMapHost", ("hipDeviceMapHost", CONV_TYPE, API_RUNTIME)), + ( + "cudaDeviceLmemResizeToMax", + ("hipDeviceLmemResizeToMax", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDeviceMask", ("hipDeviceMask", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaDeviceSetCacheConfig", + ("hipDeviceSetCacheConfig", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaDeviceGetCacheConfig", + ("hipDeviceGetCacheConfig", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaFuncAttributes", + ("hipFuncAttributes", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaFuncAttributeMaxDynamicSharedMemorySize", + ("hipFuncAttributeMaxDynamicSharedMemorySize", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaFuncAttributePreferredSharedMemoryCarveout", + ("hipFuncAttributePreferredSharedMemoryCarveout", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaFuncSetAttribute", + ("hipFuncSetAttribute", CONV_EXEC, API_RUNTIME), + ), + ("cudaFuncSetCacheConfig", ("hipFuncSetCacheConfig", CONV_CACHE, API_RUNTIME)), + ( + "cudaFuncCachePreferNone", + ("hipFuncCachePreferNone", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaFuncCachePreferShared", + ("hipFuncCachePreferShared", CONV_CACHE, API_RUNTIME), + ), + ("cudaFuncCachePreferL1", ("hipFuncCachePreferL1", CONV_CACHE, API_RUNTIME)), + ( + "cudaFuncCachePreferEqual", + ("hipFuncCachePreferEqual", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaFuncGetAttributes", + ("hipFuncGetAttributes", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFuncSetSharedMemConfig", + ("hipFuncSetSharedMemConfig", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetParameterBuffer", + ("hipGetParameterBuffer", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaSetDoubleForDevice", + ("hipSetDoubleForDevice", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaSetDoubleForHost", + ("hipSetDoubleForHost", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaConfigureCall", + ("hipConfigureCall", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaLaunch", ("hipLaunch", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaLaunchCooperativeKernel", + ("hipLaunchCooperativeKernel", CONV_EXEC, API_RUNTIME), + ), + ("cudaLaunchHostFunc", ("hipLaunchHostFunc", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaSetupArgument", + ("hipSetupArgument", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDriverGetVersion", ("hipDriverGetVersion", CONV_VERSION, API_RUNTIME)), + ( + "cudaRuntimeGetVersion", + ("hipRuntimeGetVersion", CONV_VERSION, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaOccupancyMaxPotentialBlockSize", + ("hipOccupancyMaxPotentialBlockSize", CONV_OCCUPANCY, API_RUNTIME), + ), + ( + "cudaOccupancyMaxPotentialBlockSizeWithFlags", + ( + "hipOccupancyMaxPotentialBlockSizeWithFlags", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaOccupancyMaxActiveBlocksPerMultiprocessor", + ( + "hipOccupancyMaxActiveBlocksPerMultiprocessor", + CONV_OCCUPANCY, + API_RUNTIME, + ), + ), + ( + "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + ( + "hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaOccupancyMaxPotentialBlockSizeVariableSMem", + ( + "hipOccupancyMaxPotentialBlockSizeVariableSMem", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaOccupancyMaxPotentialBlockSizeVariableSMemWithFlags", + ( + "hipOccupancyMaxPotentialBlockSizeVariableSMemWithFlags", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaDeviceCanAccessPeer", ("hipDeviceCanAccessPeer", CONV_PEER, API_RUNTIME)), + ( + "cudaDeviceDisablePeerAccess", + ("hipDeviceDisablePeerAccess", CONV_PEER, API_RUNTIME), + ), + ( + "cudaDeviceEnablePeerAccess", + ("hipDeviceEnablePeerAccess", CONV_PEER, API_RUNTIME), + ), + ("cudaMemcpyPeerAsync", ("hipMemcpyPeerAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyPeer", ("hipMemcpyPeer", CONV_MEM, API_RUNTIME)), + ( + "cudaIpcMemLazyEnablePeerAccess", + ("hipIpcMemLazyEnablePeerAccess", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDeviceSetSharedMemConfig", + ("hipDeviceSetSharedMemConfig", CONV_DEVICE, API_RUNTIME), + ), + ( + "cudaDeviceGetSharedMemConfig", + ("hipDeviceGetSharedMemConfig", CONV_DEVICE, API_RUNTIME), + ), + ( + "cudaSharedMemBankSizeDefault", + ("hipSharedMemBankSizeDefault", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaSharedMemBankSizeFourByte", + ("hipSharedMemBankSizeFourByte", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaSharedMemBankSizeEightByte", + ("hipSharedMemBankSizeEightByte", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaLimitStackSize", + ("hipLimitStackSize", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaLimitPrintfFifoSize", + ("hipLimitPrintfFifoSize", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaLimitMallocHeapSize", ("hipLimitMallocHeapSize", CONV_TYPE, API_RUNTIME)), + ( + "cudaLimitDevRuntimeSyncDepth", + ("hipLimitDevRuntimeSyncDepth", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaLimitDevRuntimePendingLaunchCount", + ( + "hipLimitDevRuntimePendingLaunchCount", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaDeviceGetLimit", ("hipDeviceGetLimit", CONV_DEVICE, API_RUNTIME)), + ( + "cudaProfilerInitialize", + ("hipProfilerInitialize", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaProfilerStart", ("hipProfilerStart", CONV_OTHER, API_RUNTIME)), + ("cudaProfilerStop", ("hipProfilerStop", CONV_OTHER, API_RUNTIME)), + ( + "cudaKeyValuePair", + ("hipKeyValuePair", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaCSV", ("hipCSV", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED)), + ("cudaReadModeElementType", ("hipReadModeElementType", CONV_TEX, API_RUNTIME)), + ( + "cudaReadModeNormalizedFloat", + ("hipReadModeNormalizedFloat", CONV_TEX, API_RUNTIME), + ), + ("cudaFilterModePoint", ("hipFilterModePoint", CONV_TEX, API_RUNTIME)), + ("cudaFilterModeLinear", ("hipFilterModeLinear", CONV_TEX, API_RUNTIME)), + ("cudaBindTexture", ("hipBindTexture", CONV_TEX, API_RUNTIME)), + ("cudaUnbindTexture", ("hipUnbindTexture", CONV_TEX, API_RUNTIME)), + ("cudaBindTexture2D", ("hipBindTexture2D", CONV_TEX, API_RUNTIME)), + ("cudaBindTextureToArray", ("hipBindTextureToArray", CONV_TEX, API_RUNTIME)), + ( + "cudaBindTextureToMipmappedArray", + ("hipBindTextureToMipmappedArray", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureAlignmentOffset", + ("hipGetTextureAlignmentOffset", CONV_TEX, API_RUNTIME), + ), + ("cudaGetTextureReference", ("hipGetTextureReference", CONV_TEX, API_RUNTIME)), + ( + "cudaChannelFormatKindSigned", + ("hipChannelFormatKindSigned", CONV_TEX, API_RUNTIME), + ), + ( + "cudaChannelFormatKindUnsigned", + ("hipChannelFormatKindUnsigned", CONV_TEX, API_RUNTIME), + ), + ( + "cudaChannelFormatKindFloat", + ("hipChannelFormatKindFloat", CONV_TEX, API_RUNTIME), + ), + ( + "cudaChannelFormatKindNone", + ("hipChannelFormatKindNone", CONV_TEX, API_RUNTIME), + ), + ("cudaCreateChannelDesc", ("hipCreateChannelDesc", CONV_TEX, API_RUNTIME)), + ("cudaGetChannelDesc", ("hipGetChannelDesc", CONV_TEX, API_RUNTIME)), + ("cudaResourceTypeArray", ("hipResourceTypeArray", CONV_TEX, API_RUNTIME)), + ( + "cudaResourceTypeMipmappedArray", + ("hipResourceTypeMipmappedArray", CONV_TEX, API_RUNTIME), + ), + ("cudaResourceTypeLinear", ("hipResourceTypeLinear", CONV_TEX, API_RUNTIME)), + ("cudaResourceTypePitch2D", ("hipResourceTypePitch2D", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatNone", ("hipResViewFormatNone", CONV_TEX, API_RUNTIME)), + ( + "cudaResViewFormatUnsignedChar1", + ("hipResViewFormatUnsignedChar1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedChar2", + ("hipResViewFormatUnsignedChar2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedChar4", + ("hipResViewFormatUnsignedChar4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedChar1", + ("hipResViewFormatSignedChar1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedChar2", + ("hipResViewFormatSignedChar2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedChar4", + ("hipResViewFormatSignedChar4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedShort1", + ("hipResViewFormatUnsignedShort1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedShort2", + ("hipResViewFormatUnsignedShort2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedShort4", + ("hipResViewFormatUnsignedShort4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedShort1", + ("hipResViewFormatSignedShort1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedShort2", + ("hipResViewFormatSignedShort2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedShort4", + ("hipResViewFormatSignedShort4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedInt1", + ("hipResViewFormatUnsignedInt1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedInt2", + ("hipResViewFormatUnsignedInt2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedInt4", + ("hipResViewFormatUnsignedInt4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedInt1", + ("hipResViewFormatSignedInt1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedInt2", + ("hipResViewFormatSignedInt2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedInt4", + ("hipResViewFormatSignedInt4", CONV_TEX, API_RUNTIME), + ), + ("cudaResViewFormatHalf1", ("hipResViewFormatHalf1", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatHalf2", ("hipResViewFormatHalf2", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatHalf4", ("hipResViewFormatHalf4", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatFloat1", ("hipResViewFormatFloat1", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatFloat2", ("hipResViewFormatFloat2", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatFloat4", ("hipResViewFormatFloat4", CONV_TEX, API_RUNTIME)), + ( + "cudaResViewFormatUnsignedBlockCompressed1", + ("hipResViewFormatUnsignedBlockCompressed1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed2", + ("hipResViewFormatUnsignedBlockCompressed2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed3", + ("hipResViewFormatUnsignedBlockCompressed3", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed4", + ("hipResViewFormatUnsignedBlockCompressed4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedBlockCompressed4", + ("hipResViewFormatSignedBlockCompressed4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed5", + ("hipResViewFormatUnsignedBlockCompressed5", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedBlockCompressed5", + ("hipResViewFormatSignedBlockCompressed5", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed6H", + ("hipResViewFormatUnsignedBlockCompressed6H", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedBlockCompressed6H", + ("hipResViewFormatSignedBlockCompressed6H", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed7", + ("hipResViewFormatUnsignedBlockCompressed7", CONV_TEX, API_RUNTIME), + ), + ("cudaAddressModeWrap", ("hipAddressModeWrap", CONV_TEX, API_RUNTIME)), + ("cudaAddressModeClamp", ("hipAddressModeClamp", CONV_TEX, API_RUNTIME)), + ("cudaAddressModeMirror", ("hipAddressModeMirror", CONV_TEX, API_RUNTIME)), + ("cudaAddressModeBorder", ("hipAddressModeBorder", CONV_TEX, API_RUNTIME)), + ("cudaCreateTextureObject", ("hipCreateTextureObject", CONV_TEX, API_RUNTIME)), + ( + "cudaDestroyTextureObject", + ("hipDestroyTextureObject", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureObjectResourceDesc", + ("hipGetTextureObjectResourceDesc", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureObjectResourceViewDesc", + ("hipGetTextureObjectResourceViewDesc", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureObjectTextureDesc", + ("hipGetTextureObjectTextureDesc", CONV_TEX, API_RUNTIME), + ), + ( + "cudaBindSurfaceToArray", + ("hipBindSurfaceToArray", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSurfaceReference", + ("hipGetSurfaceReference", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaBoundaryModeZero", + ("hipBoundaryModeZero", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaBoundaryModeClamp", + ("hipBoundaryModeClamp", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaBoundaryModeTrap", + ("hipBoundaryModeTrap", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFormatModeForced", + ("hipFormatModeForced", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFormatModeAuto", + ("hipFormatModeAuto", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaCreateSurfaceObject", + ("hipCreateSurfaceObject", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDestroySurfaceObject", + ("hipDestroySurfaceObject", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSurfaceObjectResourceDesc", + ( + "hipGetSurfaceObjectResourceDesc", + CONV_SURFACE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaIpcCloseMemHandle", ("hipIpcCloseMemHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcGetEventHandle", ("hipIpcGetEventHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcGetMemHandle", ("hipIpcGetMemHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcOpenEventHandle", ("hipIpcOpenEventHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcOpenMemHandle", ("hipIpcOpenMemHandle", CONV_DEVICE, API_RUNTIME)), + ( + "cudaGLGetDevices", + ("hipGLGetDevices", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterBuffer", + ("hipGraphicsGLRegisterBuffer", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterImage", + ("hipGraphicsGLRegisterImage", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaWGLGetDevice", + ("hipWGLGetDevice", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsMapResources", + ("hipGraphicsMapResources", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsResourceGetMappedMipmappedArray", + ( + "hipGraphicsResourceGetMappedMipmappedArray", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsResourceGetMappedPointer", + ( + "hipGraphicsResourceGetMappedPointer", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsResourceSetMapFlags", + ( + "hipGraphicsResourceSetMapFlags", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsSubResourceGetMappedArray", + ( + "hipGraphicsSubResourceGetMappedArray", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsUnmapResources", + ("hipGraphicsUnmapResources", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsUnregisterResource", + ( + "hipGraphicsUnregisterResource", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFacePositiveX", + ( + "hipGraphicsCubeFacePositiveX", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFaceNegativeX", + ( + "hipGraphicsCubeFaceNegativeX", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFacePositiveY", + ( + "hipGraphicsCubeFacePositiveY", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFaceNegativeY", + ( + "hipGraphicsCubeFaceNegativeY", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFacePositiveZ", + ( + "hipGraphicsCubeFacePositiveZ", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFaceNegativeZ", + ( + "hipGraphicsCubeFaceNegativeZ", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsMapFlagsNone", + ("hipGraphicsMapFlagsNone", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsMapFlagsReadOnly", + ( + "hipGraphicsMapFlagsReadOnly", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsMapFlagsWriteDiscard", + ( + "hipGraphicsMapFlagsWriteDiscard", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsNone", + ( + "hipGraphicsRegisterFlagsNone", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsReadOnly", + ( + "hipGraphicsRegisterFlagsReadOnly", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsWriteDiscard", + ( + "hipGraphicsRegisterFlagsWriteDiscard", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsSurfaceLoadStore", + ( + "hipGraphicsRegisterFlagsSurfaceLoadStore", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsTextureGather", + ( + "hipGraphicsRegisterFlagsTextureGather", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGLDeviceListAll", + ("HIP_GL_DEVICE_LIST_ALL", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLDeviceListCurrentFrame", + ("HIP_GL_DEVICE_LIST_CURRENT_FRAME", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLDeviceListNextFrame", + ("HIP_GL_DEVICE_LIST_NEXT_FRAME", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLGetDevices", + ("hipGLGetDevices", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterBuffer", + ("hipGraphicsGLRegisterBuffer", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterImage", + ("hipGraphicsGLRegisterImage", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaWGLGetDevice", + ("hipWGLGetDevice", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLMapFlagsNone", + ("HIP_GL_MAP_RESOURCE_FLAGS_NONE", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLMapFlagsReadOnly", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_READ_ONLY", + CONV_GL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGLMapFlagsWriteDiscard", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + CONV_GL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGLMapBufferObject", + ("hipGLMapBufferObject__", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLMapBufferObjectAsync", + ("hipGLMapBufferObjectAsync__", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLRegisterBufferObject", + ("hipGLRegisterBufferObject", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLSetBufferObjectMapFlags", + ("hipGLSetBufferObjectMapFlags", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLSetGLDevice", + ("hipGLSetGLDevice", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLUnmapBufferObject", + ("hipGLUnmapBufferObject", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLUnmapBufferObjectAsync", + ("hipGLUnmapBufferObjectAsync", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLUnregisterBufferObject", + ("hipGLUnregisterBufferObject", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9DeviceListAll", + ("HIP_D3D9_DEVICE_LIST_ALL", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9DeviceListCurrentFrame", + ( + "HIP_D3D9_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9DeviceListNextFrame", + ( + "HIP_D3D9_DEVICE_LIST_NEXT_FRAME", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9GetDevice", + ("hipD3D9GetDevice", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9GetDevices", + ("hipD3D9GetDevices", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9GetDirect3DDevice", + ("hipD3D9GetDirect3DDevice", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9SetDirect3DDevice", + ("hipD3D9SetDirect3DDevice", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D9RegisterResource", + ( + "hipGraphicsD3D9RegisterResource", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9MapFlags", + ("hipD3D9MapFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9MapFlagsNone", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_NONE", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9MapFlagsReadOnly", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9MapFlagsWriteDiscard", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9RegisterFlagsNone", + ("HIP_D3D9_REGISTER_FLAGS_NONE", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9RegisterFlagsArray", + ("HIP_D3D9_REGISTER_FLAGS_ARRAY", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9MapResources", + ("hipD3D9MapResources", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9RegisterResource", + ("hipD3D9RegisterResource", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetMappedArray", + ("hipD3D9ResourceGetMappedArray", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetMappedPitch", + ("hipD3D9ResourceGetMappedPitch", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetMappedPointer", + ( + "hipD3D9ResourceGetMappedPointer", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9ResourceGetMappedSize", + ("hipD3D9ResourceGetMappedSize", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetSurfaceDimensions", + ( + "hipD3D9ResourceGetSurfaceDimensions", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9ResourceSetMapFlags", + ("hipD3D9ResourceSetMapFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9UnmapResources", + ("hipD3D9UnmapResources", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9UnregisterResource", + ("hipD3D9UnregisterResource", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10DeviceListAll", + ("HIP_D3D10_DEVICE_LIST_ALL", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10DeviceListCurrentFrame", + ( + "HIP_D3D10_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10DeviceListNextFrame", + ( + "HIP_D3D10_DEVICE_LIST_NEXT_FRAME", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10GetDevice", + ("hipD3D10GetDevice", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10GetDevices", + ("hipD3D10GetDevices", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D10RegisterResource", + ( + "hipGraphicsD3D10RegisterResource", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10MapFlagsNone", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_NONE", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10MapFlagsReadOnly", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10MapFlagsWriteDiscard", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10RegisterFlagsNone", + ("HIP_D3D10_REGISTER_FLAGS_NONE", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10RegisterFlagsArray", + ( + "HIP_D3D10_REGISTER_FLAGS_ARRAY", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10GetDirect3DDevice", + ("hipD3D10GetDirect3DDevice", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10MapResources", + ("hipD3D10MapResources", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10RegisterResource", + ("hipD3D10RegisterResource", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10ResourceGetMappedArray", + ( + "hipD3D10ResourceGetMappedArray", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceGetMappedPitch", + ( + "hipD3D10ResourceGetMappedPitch", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceGetMappedPointer", + ( + "hipD3D10ResourceGetMappedPointer", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceGetMappedSize", + ("hipD3D10ResourceGetMappedSize", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10ResourceGetSurfaceDimensions", + ( + "hipD3D10ResourceGetSurfaceDimensions", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceSetMapFlags", + ("hipD3D10ResourceSetMapFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10SetDirect3DDevice", + ("hipD3D10SetDirect3DDevice", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10UnmapResources", + ("hipD3D10UnmapResources", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10UnregisterResource", + ("hipD3D10UnregisterResource", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11DeviceListAll", + ("HIP_D3D11_DEVICE_LIST_ALL", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11DeviceListCurrentFrame", + ( + "HIP_D3D11_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D11DeviceListNextFrame", + ( + "HIP_D3D11_DEVICE_LIST_NEXT_FRAME", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D11GetDevice", + ("hipD3D11GetDevice", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11GetDevices", + ("hipD3D11GetDevices", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D11RegisterResource", + ( + "hipGraphicsD3D11RegisterResource", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D11GetDevice", + ("hipD3D11GetDevice", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11GetDevices", + ("hipD3D11GetDevices", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D11RegisterResource", + ( + "hipGraphicsD3D11RegisterResource", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsVDPAURegisterOutputSurface", + ( + "hipGraphicsVDPAURegisterOutputSurface", + CONV_VDPAU, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsVDPAURegisterVideoSurface", + ( + "hipGraphicsVDPAURegisterVideoSurface", + CONV_VDPAU, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaVDPAUGetDevice", + ("hipVDPAUGetDevice", CONV_VDPAU, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaVDPAUSetVDPAUDevice", + ("hipVDPAUSetDevice", CONV_VDPAU, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamConsumerAcquireFrame", + ( + "hipEGLStreamConsumerAcquireFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamConsumerConnect", + ("hipEGLStreamConsumerConnect", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamConsumerConnectWithFlags", + ( + "hipEGLStreamConsumerConnectWithFlags", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamConsumerReleaseFrame", + ( + "hipEGLStreamConsumerReleaseFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamProducerConnect", + ("hipEGLStreamProducerConnect", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamProducerDisconnect", + ("hipEGLStreamProducerDisconnect", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamProducerPresentFrame", + ( + "hipEGLStreamProducerPresentFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamProducerReturnFrame", + ("hipEGLStreamProducerReturnFrame", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsEGLRegisterImage", + ("hipGraphicsEGLRegisterImage", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsResourceGetMappedEglFrame", + ( + "hipGraphicsResourceGetMappedEglFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cublasInit", ("hipblasInit", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasShutdown", + ("hipblasShutdown", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetVersion", + ("hipblasGetVersion", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetError", + ("hipblasGetError", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasAlloc", ("hipblasAlloc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasFree", ("hipblasFree", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSetKernelStream", + ("hipblasSetKernelStream", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetAtomicsMode", + ("hipblasGetAtomicsMode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSetAtomicsMode", + ("hipblasSetAtomicsMode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetMathMode", + ("hipblasGetMathMode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSetMathMode", + ("hipblasSetMathMode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("CUBLAS_OP_N", ("HIPBLAS_OP_N", CONV_NUMERIC_LITERAL, API_BLAS)), + ( + "CUBLAS_OP_T", + ("HIPBLAS_OP_T", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_OP_C", + ("HIPBLAS_OP_C", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_SUCCESS", + ("HIPBLAS_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_NOT_INITIALIZED", + ("HIPBLAS_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_ALLOC_FAILED", + ("HIPBLAS_STATUS_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_INVALID_VALUE", + ("HIPBLAS_STATUS_INVALID_VALUE", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_MAPPING_ERROR", + ("HIPBLAS_STATUS_MAPPING_ERROR", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_EXECUTION_FAILED", + ("HIPBLAS_STATUS_EXECUTION_FAILED", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_INTERNAL_ERROR", + ("HIPBLAS_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_NOT_SUPPORTED", + ("HIPBLAS_STATUS_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_ARCH_MISMATCH", + ("HIPBLAS_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_FILL_MODE_LOWER", + ("HIPBLAS_FILL_MODE_LOWER", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_FILL_MODE_UPPER", + ("HIPBLAS_FILL_MODE_UPPER", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_DIAG_NON_UNIT", + ("HIPBLAS_DIAG_NON_UNIT", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ("CUBLAS_DIAG_UNIT", ("HIPBLAS_DIAG_UNIT", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUBLAS_SIDE_LEFT", ("HIPBLAS_SIDE_LEFT", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUBLAS_SIDE_RIGHT", ("HIPBLAS_SIDE_RIGHT", CONV_NUMERIC_LITERAL, API_BLAS)), + ( + "CUBLAS_POINTER_MODE_HOST", + ("HIPBLAS_POINTER_MODE_HOST", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_POINTER_MODE_DEVICE", + ("HIPBLAS_POINTER_MODE_DEVICE", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_ATOMICS_NOT_ALLOWED", + ( + "HIPBLAS_ATOMICS_NOT_ALLOWED", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_ATOMICS_ALLOWED", + ( + "HIPBLAS_ATOMICS_ALLOWED", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_DATA_FLOAT", + ( + "HIPBLAS_DATA_FLOAT", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_DATA_DOUBLE", + ( + "HIPBLAS_DATA_DOUBLE", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_DATA_HALF", + ("HIPBLAS_DATA_HALF", CONV_NUMERIC_LITERAL, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "CUBLAS_DATA_INT8", + ("HIPBLAS_DATA_INT8", CONV_NUMERIC_LITERAL, API_BLAS, HIP_UNSUPPORTED), + ), + ("CUBLAS_GEMM_DEFAULT", ("HIPBLAS_GEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUBLAS_GEMM_DEFAULT_TENSOR_OP", ("HIPBLAS_GEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_BLAS)), + ("cublasCreate", ("hipblasCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasDestroy", ("hipblasDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasSetVector", ("hipblasSetVector", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetVector", ("hipblasGetVector", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSetVectorAsync", + ("hipblasSetVectorAsync", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetVectorAsync", + ("hipblasGetVectorAsync", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSetMatrix", ("hipblasSetMatrix", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetMatrix", ("hipblasGetMatrix", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasGetMatrixAsync", + ("hipblasGetMatrixAsync", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSetMatrixAsync", + ("hipblasSetMatrixAsync", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasXerbla", ("hipblasXerbla", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSnrm2", ("hipblasSnrm2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDnrm2", ("hipblasDnrm2", CONV_MATH_FUNC, API_BLAS)), + ("cublasScnrm2", ("hipblasScnrm2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDznrm2", ("hipblasDznrm2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasNrm2Ex", + ("hipblasNrm2Ex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSdot", ("hipblasSdot", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSdotBatched", + ("hipblasSdotBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDdot", ("hipblasDdot", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDdotBatched", + ("hipblasDdotBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasCdotu", ("hipblasCdotu", CONV_MATH_FUNC, API_BLAS)), + ("cublasCdotc", ("hipblasCdotc", CONV_MATH_FUNC, API_BLAS)), + ("cublasZdotu", ("hipblasZdotu", CONV_MATH_FUNC, API_BLAS)), + ("cublasZdotc", ("hipblasZdotc", CONV_MATH_FUNC, API_BLAS)), + ("cublasSscal", ("hipblasSscal", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSscalBatched", + ("hipblasSscalBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDscal", ("hipblasDscal", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDscalBatched", + ("hipblasDscalBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasCscal", ("hipblasCscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsscal", ("hipblasCsscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZscal", ("hipblasZscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdscal", ("hipblasZdscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSaxpy", ("hipblasSaxpy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSaxpyBatched", + ("hipblasSaxpyBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDaxpy", ("hipblasDaxpy", CONV_MATH_FUNC, API_BLAS)), + ("cublasCaxpy", ("hipblasCaxpy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZaxpy", ("hipblasZaxpy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasScopy", ("hipblasScopy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasScopyBatched", + ("hipblasScopyBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDcopy", ("hipblasDcopy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDcopyBatched", + ("hipblasDcopyBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasCcopy", ("hipblasCcopy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZcopy", ("hipblasZcopy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSswap", ("hipblasSswap", CONV_MATH_FUNC, API_BLAS)), + ("cublasDswap", ("hipblasDswap", CONV_MATH_FUNC, API_BLAS)), + ("cublasCswap", ("hipblasCswap", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZswap", ("hipblasZswap", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIsamax", ("hipblasIsamax", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamax", ("hipblasIdamax", CONV_MATH_FUNC, API_BLAS)), + ("cublasIcamax", ("hipblasIcamax", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIzamax", ("hipblasIzamax", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIsamin", ("hipblasIsamin", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamin", ("hipblasIdamin", CONV_MATH_FUNC, API_BLAS)), + ("cublasIcamin", ("hipblasIcamin", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIzamin", ("hipblasIzamin", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSasum", ("hipblasSasum", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSasumBatched", + ("hipblasSasumBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDasum", ("hipblasDasum", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDasumBatched", + ("hipblasDasumBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasScasum", ("hipblasScasum", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDzasum", ("hipblasDzasum", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrot", ("hipblasSrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrot", ("hipblasDrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCrot", ("hipblasCrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsrot", ("hipblasCsrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZrot", ("hipblasZrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdrot", ("hipblasZdrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrotg", ("hipblasSrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrotg", ("hipblasDrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCrotg", ("hipblasCrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZrotg", ("hipblasZrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrotm", ("hipblasSrotm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrotm", ("hipblasDrotm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrotmg", ("hipblasSrotmg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrotmg", ("hipblasDrotmg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSgemv", ("hipblasSgemv", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSgemvBatched", + ("hipblasSgemvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDgemv", ("hipblasDgemv", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgemv", ("hipblasCgemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgemv", ("hipblasZgemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSgbmv", ("hipblasSgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDgbmv", ("hipblasDgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCgbmv", ("hipblasCgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgbmv", ("hipblasZgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrmv", ("hipblasStrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrmv", ("hipblasDtrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrmv", ("hipblasCtrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrmv", ("hipblasZtrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStbmv", ("hipblasStbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtbmv", ("hipblasDtbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtbmv", ("hipblasCtbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtbmv", ("hipblasZtbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStpmv", ("hipblasStpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtpmv", ("hipblasDtpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtpmv", ("hipblasCtpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtpmv", ("hipblasZtpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrsv", ("hipblasStrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrsv", ("hipblasDtrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrsv", ("hipblasCtrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrsv", ("hipblasZtrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStpsv", ("hipblasStpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtpsv", ("hipblasDtpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtpsv", ("hipblasCtpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtpsv", ("hipblasZtpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStbsv", ("hipblasStbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtbsv", ("hipblasDtbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtbsv", ("hipblasCtbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtbsv", ("hipblasZtbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsymv", ("hipblasSsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsymv", ("hipblasDsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsymv", ("hipblasCsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsymv", ("hipblasZsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChemv", ("hipblasChemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhemv", ("hipblasZhemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsbmv", ("hipblasSsbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsbmv", ("hipblasDsbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChbmv", ("hipblasChbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhbmv", ("hipblasZhbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspmv", ("hipblasSspmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspmv", ("hipblasDspmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpmv", ("hipblasChpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpmv", ("hipblasZhpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSger", ("hipblasSger", CONV_MATH_FUNC, API_BLAS)), + ("cublasDger", ("hipblasDger", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgeru", ("hipblasCgeru", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCgerc", ("hipblasCgerc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgeru", ("hipblasZgeru", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgerc", ("hipblasZgerc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyr", ("hipblasSsyr", CONV_MATH_FUNC, API_BLAS)), + ("cublasDsyr", ("hipblasDsyr", CONV_MATH_FUNC, API_BLAS)), + ("cublasCher", ("hipblasCher", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher", ("hipblasZher", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspr", ("hipblasSspr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspr", ("hipblasDspr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpr", ("hipblasChpr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpr", ("hipblasZhpr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyr2", ("hipblasSsyr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyr2", ("hipblasDsyr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCher2", ("hipblasCher2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher2", ("hipblasZher2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspr2", ("hipblasSspr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspr2", ("hipblasDspr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpr2", ("hipblasChpr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpr2", ("hipblasZhpr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSgemmBatched", + ("hipblasSgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgemmBatched", + ("hipblasDgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasHgemmBatched", + ("hipblasHgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgemmStridedBatched", + ("hipblasSgemmStridedBatched", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasDgemmStridedBatched", + ("hipblasDgemmStridedBatched", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasHgemmStridedBatched", + ("hipblasHgemmStridedBatched", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasCgemmBatched", + ("hipblasCgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemm3mBatched", + ("hipblasCgemm3mBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemmBatched", + ("hipblasZgemmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemmStridedBatched", + ( + "hipblasCgemmStridedBatched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "cublasCgemm3mStridedBatched", + ( + "hipblasCgemm3mStridedBatched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "cublasZgemmStridedBatched", + ( + "hipblasZgemmStridedBatched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "cublasHgemmStridedBatched", + ( + "hipblasHgemmStridedBatched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ("cublasSgemm", ("hipblasSgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgemm", ("hipblasDgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgemm", ("hipblasCgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasZgemm", ("hipblasZgemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasHgemm", ("hipblasHgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasSsyrk", ("hipblasSsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyrk", ("hipblasDsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyrk", ("hipblasCsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyrk", ("hipblasZsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCherk", ("hipblasCherk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZherk", ("hipblasZherk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyr2k", ("hipblasSsyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyr2k", ("hipblasDsyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyr2k", ("hipblasCsyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyr2k", ("hipblasZyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyrkx", ("hipblasSsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyrkx", ("hipblasDsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyrkx", ("hipblasCsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyrkx", ("hipblasZsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCher2k", ("hipblasCher2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher2k", ("hipblasZher2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCherkx", ("hipblasCherkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZherkx", ("hipblasZherkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsymm", ("hipblasSsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsymm", ("hipblasDsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsymm", ("hipblasCsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsymm", ("hipblasZsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChemm", ("hipblasChemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhemm", ("hipblasZhemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrsm", ("hipblasStrsm", CONV_MATH_FUNC, API_BLAS)), + ("cublasDtrsm", ("hipblasDtrsm", CONV_MATH_FUNC, API_BLAS)), + ("cublasCtrsm", ("hipblasCtrsm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrsm", ("hipblasZtrsm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasStrsmBatched", + ("hipblasStrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsmBatched", + ("hipblasDtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsmBatched", + ("hipblasCtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsmBatched", + ("hipblasZtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasStrmm", ("hipblasStrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrmm", ("hipblasDtrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrmm", ("hipblasCtrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrmm", ("hipblasZtrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSgeam", ("hipblasSgeam", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgeam", ("hipblasDgeam", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgeam", ("hipblasCgeam", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgeam", ("hipblasZgeam", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSgetrfBatched", + ("hipblasSgetrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgetrfBatched", + ("hipblasDgetrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgetrfBatched", + ("hipblasCgetrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgetrfBatched", + ("hipblasZgetrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgetriBatched", + ("hipblasSgetriBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgetriBatched", + ("hipblasDgetriBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgetriBatched", + ("hipblasCgetriBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgetriBatched", + ("hipblasZgetriBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgetrsBatched", + ("hipblasSgetrsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgetrsBatched", + ("hipblasDgetrsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgetrsBatched", + ("hipblasCgetrsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgetrsBatched", + ("hipblasZgetrsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrsmBatched", + ("hipblasStrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsmBatched", + ("hipblasDtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsmBatched", + ("hipblasCtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsmBatched", + ("hipblasZtrsmBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSmatinvBatched", + ("hipblasSmatinvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDmatinvBatched", + ("hipblasDmatinvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCmatinvBatched", + ("hipblasCmatinvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZmatinvBatched", + ("hipblasZmatinvBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgeqrfBatched", + ("hipblasSgeqrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgeqrfBatched", + ("hipblasDgeqrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgeqrfBatched", + ("hipblasCgeqrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgeqrfBatched", + ("hipblasZgeqrfBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgelsBatched", + ("hipblasSgelsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgelsBatched", + ("hipblasDgelsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgelsBatched", + ("hipblasCgelsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgelsBatched", + ("hipblasZgelsBatched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSdgmm", ("hipblasSdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDdgmm", ("hipblasDdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCdgmm", ("hipblasCdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdgmm", ("hipblasZdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStpttr", ("hipblasStpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtpttr", ("hipblasDtpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtpttr", ("hipblasCtpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtpttr", ("hipblasZtpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrttp", ("hipblasStrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrttp", ("hipblasDtrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrttp", ("hipblasCtrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrttp", ("hipblasZtrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCreate_v2", ("hipblasCreate_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDestroy_v2", ("hipblasDestroy_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasGetVersion_v2", + ("hipblasGetVersion_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSetWorkspace", ("hipblasSetWorkspace", CONV_MATH_FUNC, API_BLAS)), + ("cublasSetStream", ("hipblasSetStream", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetStream", ("hipblasGetStream", CONV_MATH_FUNC, API_BLAS)), + ("cublasSetStream_v2", ("hipblasSetStream_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetStream_v2", ("hipblasGetStream_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasGetPointerMode", + ("hipblasGetPointerMode", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasSetPointerMode", + ("hipblasSetPointerMode", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasGetPointerMode_v2", + ("hipblasGetPointerMode_v2", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasSetPointerMode_v2", + ("hipblasSetPointerMode_v2", CONV_MATH_FUNC, API_BLAS), + ), + ("cublasSgemv_v2", ("hipblasSgemv_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgemv_v2", ("hipblasDgemv_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCgemv_v2", + ("hipblasCgemv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemv_v2", + ("hipblasZgemv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgbmv_v2", + ("hipblasSgbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgbmv_v2", + ("hipblasDgbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgbmv_v2", + ("hipblasCgbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgbmv_v2", + ("hipblasZgbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrmv_v2", + ("hipblasStrmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrmv_v2", + ("hipblasDtrmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrmv_v2", + ("hipblasCtrmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrmv_v2", + ("hipblasZtrmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStbmv_v2", + ("hipblasStbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtbmv_v2", + ("hipblasDtbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtbmv_v2", + ("hipblasCtbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtbmv_v2", + ("hipblasZtbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStpmv_v2", + ("hipblasStpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtpmv_v2", + ("hipblasDtpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtpmv_v2", + ("hipblasCtpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtpmv_v2", + ("hipblasZtpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrsv_v2", + ("hipblasStrsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsv_v2", + ("hipblasDtrsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsv_v2", + ("hipblasCtrsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsv_v2", + ("hipblasZtrsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStpsv_v2", + ("hipblasStpsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtpsv_v2", + ("hipblasDtpsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtpsv_v2", + ("hipblasCtpsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtpsv_v2", + ("hipblasZtpsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStbsv_v2", + ("hipblasStbsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtbsv_v2", + ("hipblasDtbsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtbsv_v2", + ("hipblasCtbsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtbsv_v2", + ("hipblasZtbsv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsymv_v2", + ("hipblasSsymv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsymv_v2", + ("hipblasDsymv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsymv_v2", + ("hipblasCsymv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsymv_v2", + ("hipblasZsymv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChemv_v2", + ("hipblasChemv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhemv_v2", + ("hipblasZhemv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsbmv_v2", + ("hipblasSsbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsbmv_v2", + ("hipblasDsbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChbmv_v2", + ("hipblasChbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhbmv_v2", + ("hipblasZhbmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSspmv_v2", + ("hipblasSspmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDspmv_v2", + ("hipblasDspmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChpmv_v2", + ("hipblasChpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhpmv_v2", + ("hipblasZhpmv_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSger_v2", ("hipblasSger_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDger_v2", ("hipblasDger_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCgeru_v2", + ("hipblasCgeru_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgerc_v2", + ("hipblasCergc_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgeru_v2", + ("hipblasZgeru_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgerc_v2", + ("hipblasZgerc_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSsyr_v2", ("hipblasSsyr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyr_v2", ("hipblasDsyr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyr_v2", ("hipblasCsyr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyr_v2", ("hipblasZsyr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCher_v2", ("hipblasCher_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher_v2", ("hipblasZher_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspr_v2", ("hipblasSspr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspr_v2", ("hipblasDspr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpr_v2", ("hipblasChpr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpr_v2", ("hipblasZhpr_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSsyr2_v2", + ("hipblasSsyr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsyr2_v2", + ("hipblasDsyr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyr2_v2", + ("hipblasCsyr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsyr2_v2", + ("hipblasZsyr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCher2_v2", + ("hipblasCher2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZher2_v2", + ("hipblasZher2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSspr2_v2", + ("hipblasSspr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDspr2_v2", + ("hipblasDspr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChpr2_v2", + ("hipblasChpr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhpr2_v2", + ("hipblasZhpr2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSgemm_v2", ("hipblasSgemm_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgemm_v2", ("hipblasDgemm_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCgemm_v2", + ("hipblasCgemm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemm3m", + ("hipblasCgemm3m", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemm3mEx", + ("hipblasCgemm3mEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemm_v2", + ("hipblasZgemm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemm3m", + ("hipblasZgemm3m", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgemmEx", + ("hipblasSgemmEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasGemmEx", ("hipblasGemmEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasGemmBatchedEx", + ("hipblasGemmBatchedEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGemmStridedBatchedEx", + ("hipblasGemmStridedBatchedEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemmEx", + ("hipblasCgemmEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasUint8gemmBias", + ("hipblasUint8gemmBias", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsyrk_v2", + ("hipblasSsyrk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsyrk_v2", + ("hipblasDsyrk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyrk_v2", + ("hipblasCsyrk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsyrk_v2", + ("hipblasZsyrk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyrkEx", + ("hipblasCsyrkEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyrk3mEx", + ("hipblasCsyrk3mEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCherk_v2", + ("hipblasCherk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCherkEx", + ("hipblasCherkEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCherk3mEx", + ("hipblasCherk3mEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZherk_v2", + ("hipblasZherk_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsyr2k_v2", + ("hipblasSsyr2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsyr2k_v2", + ("hipblasDsyr2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyr2k_v2", + ("hipblasCsyr2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsyr2k_v2", + ("hipblasZsyr2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCher2k_v2", + ("hipblasCher2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZher2k_v2", + ("hipblasZher2k_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsymm_v2", + ("hipblasSsymm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsymm_v2", + ("hipblasDsymm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsymm_v2", + ("hipblasCsymm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsymm_v2", + ("hipblasZsymm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChemm_v2", + ("hipblasChemm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhemm_v2", + ("hipblasZhemm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrsm_v2", + ("hipblasStrsm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsm_v2", + ("hipblasDtrsm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsm_v2", + ("hipblasCtrsm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsm_v2", + ("hipblasZtrsm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrmm_v2", + ("hipblasStrmm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrmm_v2", + ("hipblasDtrmm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrmm_v2", + ("hipblasCtrmm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrmm_v2", + ("hipblasZtrmm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSnrm2_v2", ("hipblasSnrm2_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDnrm2_v2", ("hipblasDnrm2_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasScnrm2_v2", + ("hipblasScnrm2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDznrm2_v2", + ("hipblasDznrm2_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDotEx", ("hipblasDotEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDotcEx", ("hipblasDotcEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSdot_v2", ("hipblasSdot_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDdot_v2", ("hipblasDdot_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCdotu_v2", + ("hipblasCdotu_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCdotc_v2", + ("hipblasCdotc_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZdotu_v2", + ("hipblasZdotu_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZdotc_v2", + ("hipblasZdotc_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasScalEx", ("hipblasScalEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSscal_v2", ("hipblasSscal_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDscal_v2", ("hipblasDscal_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCscal_v2", + ("hipblasCscal_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsscal_v2", + ("hipblasCsscal_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZscal_v2", + ("hipblasZcsal_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZdscal_v2", + ("hipblasZdscal_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasAxpyEx", ("hipblasAxpyEx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSaxpy_v2", ("hipblasSaxpy_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDaxpy_v2", ("hipblasDaxpy_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCaxpy_v2", + ("hipblasCaxpy_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZaxpy_v2", + ("hipblasZaxpy_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasScopy_v2", ("hipblasScopy_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDcopy_v2", ("hipblasDcopy_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCcopy_v2", + ("hipblasCcopy_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZcopy_v2", + ("hipblasZcopy_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSswap_v2", ("hipblasSswap_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDswap_v2", ("hipblasDswap_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCswap_v2", + ("hipblasCswap_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZswap_v2", + ("hipblasZswap_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasIsamax_v2", ("hipblasIsamax_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamax_v2", ("hipblasIdamax_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasIcamax_v2", + ("hipblasIcamax_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasIzamax_v2", + ("hipblasIzamax_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasIsamin_v2", ("hipblasIsamin_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamin_v2", ("hipblasIdamin_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasIcamin_v2", + ("hipblasIcamin_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasIzamin_v2", + ("hipblasIzamin_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSasum_v2", ("hipblasSasum_v2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDasum_v2", ("hipblasDasum_v2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasScasum_v2", + ("hipblasScasum_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDzasum_v2", + ("hipblasDzasum_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSrot_v2", ("hipblasSrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrot_v2", ("hipblasDrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCrot_v2", ("hipblasCrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasCsrot_v2", + ("hipblasCsrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasZrot_v2", ("hipblasZrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasZdrot_v2", + ("hipblasZdrot_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSrotg_v2", + ("hipblasSrotg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDrotg_v2", + ("hipblasDrotg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCrotg_v2", + ("hipblasCrotg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZrotg_v2", + ("hipblasZrotg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSrotm_v2", + ("hipblasSrotm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDrotm_v2", + ("hipblasDrotm_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSrotmg_v2", + ("hipblasSrotmg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDrotmg_v2", + ("hipblasDrotmg_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasComputeType_t", + ("hipblasComputeType_t", CONV_MATH_FUNC, API_BLAS) + ), + ( + "CUBLAS_COMPUTE_32I", + ("HIPBLAS_COMPUTE_32I", CONV_MATH_FUNC, API_BLAS) + ), + ( + "CUBLAS_COMPUTE_32F", + ("HIPBLAS_COMPUTE_32F", CONV_MATH_FUNC, API_BLAS) + ), + ( + "CUBLAS_COMPUTE_32F_FAST_TF32", + ("HIPBLAS_COMPUTE_32F_FAST_TF32", CONV_MATH_FUNC, API_BLAS) + ), + ( + "CUBLAS_COMPUTE_64F", + ("HIPBLAS_COMPUTE_64F", CONV_MATH_FUNC, API_BLAS) + ), + ("cublasLtEpilogue_t", ("hipblasLtEpilogue_t", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_DEFAULT", ("HIPBLASLT_EPILOGUE_DEFAULT", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_RELU", ("HIPBLASLT_EPILOGUE_RELU", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_BIAS", ("HIPBLASLT_EPILOGUE_BIAS", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_RELU_BIAS", ("HIPBLASLT_EPILOGUE_RELU_BIAS", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_GELU", ("HIPBLASLT_EPILOGUE_GELU", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_EPILOGUE_GELU_BIAS", ("HIPBLASLT_EPILOGUE_GELU_BIAS", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtHandle_t", ("hipblasLtHandle_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDesc_t", ("hipblasLtMatmulDesc_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescOpaque_t", ("hipblasLtMatmulDescOpaque_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescAttributes_t", ("hipblasLtMatmulDescAttributes_t", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_TRANSA", ("HIPBLASLT_MATMUL_DESC_TRANSA", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_TRANSB", ("HIPBLASLT_MATMUL_DESC_TRANSB", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_EPILOGUE", ("HIPBLASLT_MATMUL_DESC_EPILOGUE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_BIAS_POINTER", ("HIPBLASLT_MATMUL_DESC_BIAS_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_A_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_B_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_D_SCALE_POINTER", ("HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_A_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_A_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_B_SCALE_MODE", ("HIPBLASLT_MATMUL_DESC_B_SCALE_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", ("HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_AMAX_D_POINTER", ("HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", ("HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayout_t", ("hipblasLtMatrixLayout_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutOpaque_t", ("hipblasLtMatrixLayoutOpaque_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutAttribute_t", ("hipblasLtMatrixLayoutAttribute_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutCreate", ("hipblasLtMatrixLayoutCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutDestroy", ("hipblasLtMatrixLayoutDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatrixLayoutSetAttribute", ("hipblasLtMatrixLayoutSetAttribute", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT", ("HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET", ("HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreference_t", ("hipblasLtMatmulPreference_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceOpaque_t", ("hipblasLtMatmulPreferenceOpaque_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceAttributes_t", ("hipblasLtMatmulPreferenceAttributes_t", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_PREF_SEARCH_MODE", ("HIPBLASLT_MATMUL_PREF_SEARCH_MODE", CONV_MATH_FUNC, API_BLAS)), + ("CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES", ("HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulAlgo_t", ("hipblasLtMatmulAlgo_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulHeuristicResult_t", ("hipblasLtMatmulHeuristicResult_t", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtCreate", ("hipblasLtCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtDestroy", ("hipblasLtDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescCreate", ("hipblasLtMatmulDescCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescDestroy", ("hipblasLtMatmulDescDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulDescSetAttribute", ("hipblasLtMatmulDescSetAttribute", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceCreate", ("hipblasLtMatmulPreferenceCreate", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceDestroy", ("hipblasLtMatmulPreferenceDestroy", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulPreferenceSetAttribute", ("hipblasLtMatmulPreferenceSetAttribute", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmulAlgoGetHeuristic", ("hipblasLtMatmulAlgoGetHeuristic", CONV_MATH_FUNC, API_BLAS)), + ("cublasLtMatmul", ("hipblasLtMatmul", CONV_MATH_FUNC, API_BLAS)), + ( + "CURAND_STATUS_SUCCESS", + ("HIPRAND_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_VERSION_MISMATCH", + ("HIPRAND_STATUS_VERSION_MISMATCH", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_NOT_INITIALIZED", + ("HIPRAND_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_ALLOCATION_FAILED", + ("HIPRAND_STATUS_ALLOCATION_FAILED", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_TYPE_ERROR", + ("HIPRAND_STATUS_TYPE_ERROR", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_OUT_OF_RANGE", + ("HIPRAND_STATUS_OUT_OF_RANGE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_LENGTH_NOT_MULTIPLE", + ("HIPRAND_STATUS_LENGTH_NOT_MULTIPLE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED", + ( + "HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED", + CONV_NUMERIC_LITERAL, + API_RAND, + ), + ), + ( + "CURAND_STATUS_LAUNCH_FAILURE", + ("HIPRAND_STATUS_LAUNCH_FAILURE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_PREEXISTING_FAILURE", + ("HIPRAND_STATUS_PREEXISTING_FAILURE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_INITIALIZATION_FAILED", + ("HIPRAND_STATUS_INITIALIZATION_FAILED", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_ARCH_MISMATCH", + ("HIPRAND_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_INTERNAL_ERROR", + ("HIPRAND_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_RAND), + ), + ("CURAND_RNG_TEST", ("HIPRAND_RNG_TEST", CONV_NUMERIC_LITERAL, API_RAND)), + ( + "mtgp32dc_params_fast_11213", + ("mtgp32dc_params_fast_11213", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_DEFAULT", + ("HIPRAND_RNG_PSEUDO_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_XORWOW", + ("HIPRAND_RNG_PSEUDO_XORWOW", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_MRG32K3A", + ("HIPRAND_RNG_PSEUDO_MRG32K3A", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_MTGP32", + ("HIPRAND_RNG_PSEUDO_MTGP32", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_MT19937", + ("HIPRAND_RNG_PSEUDO_MT19937", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_PHILOX4_32_10", + ("HIPRAND_RNG_PSEUDO_PHILOX4_32_10", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_DEFAULT", + ("HIPRAND_RNG_QUASI_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SOBOL32", + ("HIPRAND_RNG_QUASI_SOBOL32", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SCRAMBLED_SOBOL32", + ("HIPRAND_RNG_QUASI_SCRAMBLED_SOBOL32", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SOBOL64", + ("HIPRAND_RNG_QUASI_SOBOL64", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SCRAMBLED_SOBOL64", + ("HIPRAND_RNG_QUASI_SCRAMBLED_SOBOL64", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "curand_ORDERING_PSEUDO_BEST", + ( + "HIPRAND_ORDERING_PSEUDO_BEST", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_ORDERING_PSEUDO_DEFAULT", + ( + "HIPRAND_ORDERING_PSEUDO_DEFAULT", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_ORDERING_PSEUDO_SEEDED", + ( + "HIPRAND_ORDERING_PSEUDO_SEEDED", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_ORDERING_QUASI_DEFAULT", + ( + "HIPRAND_ORDERING_QUASI_DEFAULT", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_DIRECTION_VECTORS_32_JOEKUO6", + ( + "HIPRAND_DIRECTION_VECTORS_32_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_SCRAMBLED_DIRECTION_VECTORS_32_JOEKUO6", + ( + "HIPRAND_SCRAMBLED_DIRECTION_VECTORS_32_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_DIRECTION_VECTORS_64_JOEKUO6", + ( + "HIPRAND_DIRECTION_VECTORS_64_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_SCRAMBLED_DIRECTION_VECTORS_64_JOEKUO6", + ( + "HIPRAND_SCRAMBLED_DIRECTION_VECTORS_64_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_CHOOSE_BEST", + ("HIPRAND_CHOOSE_BEST", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_ITR", + ("HIPRAND_ITR", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_KNUTH", + ("HIPRAND_KNUTH", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_HITR", + ("HIPRAND_HITR", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ("curand_M1", ("HIPRAND_M1", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED)), + ("curand_M2", ("HIPRAND_M2", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED)), + ( + "curand_BINARY_SEARCH", + ("HIPRAND_BINARY_SEARCH", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_DISCRETE_GAUSS", + ("HIPRAND_DISCRETE_GAUSS", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_REJECTION", + ("HIPRAND_REJECTION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_DEVICE_API", + ("HIPRAND_DEVICE_API", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_FAST_REJECTION", + ("HIPRAND_FAST_REJECTION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_3RD", + ("HIPRAND_3RD", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_DEFINITION", + ("HIPRAND_DEFINITION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_POISSON", + ("HIPRAND_POISSON", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ("curandCreateGenerator", ("hiprandCreateGenerator", CONV_MATH_FUNC, API_RAND)), + ( + "curandCreateGeneratorHost", + ("hiprandCreateGeneratorHost", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandCreatePoissonDistribution", + ("hiprandCreatePoissonDistribution", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandDestroyDistribution", + ("hiprandDestroyDistribution", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandDestroyGenerator", + ("hiprandDestroyGenerator", CONV_MATH_FUNC, API_RAND), + ), + ("curandGenerate", ("hiprandGenerate", CONV_MATH_FUNC, API_RAND)), + ( + "curandGenerateLogNormal", + ("hiprandGenerateLogNormal", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandGenerateLogNormalDouble", + ("hiprandGenerateLogNormalDouble", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandGenerateLongLong", + ("hiprandGenerateLongLong", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ("curandGenerateNormal", ("hiprandGenerateNormal", CONV_MATH_FUNC, API_RAND)), + ( + "curandGenerateNormalDouble", + ("hiprandGenerateNormalDouble", CONV_MATH_FUNC, API_RAND), + ), + ("curandGeneratePoisson", ("hiprandGeneratePoisson", CONV_MATH_FUNC, API_RAND)), + ("curandGenerateSeeds", ("hiprandGenerateSeeds", CONV_MATH_FUNC, API_RAND)), + ("curandGenerateUniform", ("hiprandGenerateUniform", CONV_MATH_FUNC, API_RAND)), + ( + "curandGenerateUniformDouble", + ("hiprandGenerateUniformDouble", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandGetDirectionVectors32", + ("hiprandGetDirectionVectors32", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandGetDirectionVectors64", + ("hiprandGetDirectionVectors64", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandGetProperty", + ("hiprandGetProperty", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandGetScrambleConstants32", + ( + "hiprandGetScrambleConstants32", + CONV_MATH_FUNC, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curandGetScrambleConstants64", + ( + "hiprandGetScrambleConstants64", + CONV_MATH_FUNC, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ("curandGetVersion", ("hiprandGetVersion", CONV_MATH_FUNC, API_RAND)), + ( + "curandSetGeneratorOffset", + ("hiprandSetGeneratorOffset", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandSetGeneratorOrdering", + ("hiprandSetGeneratorOrdering", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandSetPseudoRandomGeneratorSeed", + ("hiprandSetPseudoRandomGeneratorSeed", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandSetQuasiRandomGeneratorDimensions", + ("hiprandSetQuasiRandomGeneratorDimensions", CONV_MATH_FUNC, API_RAND), + ), + ("curandSetStream", ("hiprandSetStream", CONV_MATH_FUNC, API_RAND)), + ("curand", ("hiprand", CONV_DEVICE_FUNC, API_RAND)), + ("curand4", ("hiprand4", CONV_DEVICE_FUNC, API_RAND)), + ("curand_init", ("hiprand_init", CONV_DEVICE_FUNC, API_RAND)), + ("curand_log_normal", ("hiprand_log_normal", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_log_normal_double", + ("hiprand_log_normal_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_log_normal2", ("hiprand_log_normal2", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_log_normal2_double", + ("hiprand_log_normal2_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_log_normal4", ("hiprand_log_normal4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_log_normal4_double", + ("hiprand_log_normal4_double", CONV_DEVICE_FUNC, API_RAND), + ), + ( + "curand_mtgp32_single", + ("hiprand_mtgp32_single", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_mtgp32_single_specific", + ( + "hiprand_mtgp32_single_specific", + CONV_DEVICE_FUNC, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_mtgp32_specific", + ("hiprand_mtgp32_specific", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ("curand_normal", ("hiprand_normal", CONV_DEVICE_FUNC, API_RAND)), + ( + "curandMakeMTGP32Constants", + ("hiprandMakeMTGP32Constants", CONV_DEVICE_FUNC, API_RAND), + ), + ( + "curandMakeMTGP32KernelState", + ("hiprandMakeMTGP32KernelState", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_normal_double", ("hiprand_normal_double", CONV_DEVICE_FUNC, API_RAND)), + ("curand_normal2", ("hiprand_normal2", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_normal2_double", + ("hiprand_normal2_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_normal4", ("hiprand_normal4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_normal4_double", + ("hiprand_normal4_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_uniform", ("hiprand_uniform", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_uniform_double", + ("hiprand_uniform_double", CONV_DEVICE_FUNC, API_RAND), + ), + ( + "curand_uniform2_double", + ("hiprand_uniform2_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_uniform4", ("hiprand_uniform4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_uniform4_double", + ("hiprand_uniform4_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_discrete", ("hiprand_discrete", CONV_DEVICE_FUNC, API_RAND)), + ("curand_discrete4", ("hiprand_discrete4", CONV_DEVICE_FUNC, API_RAND)), + ("curand_poisson", ("hiprand_poisson", CONV_DEVICE_FUNC, API_RAND)), + ("curand_poisson4", ("hiprand_poisson4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_Philox4x32_10", + ("hiprand_Philox4x32_10", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ("mtgp32_kernel_params", ("mtgp32_kernel_params_t", CONV_MATH_FUNC, API_RAND)), + ("CUFFT_FORWARD", ("HIPFFT_FORWARD", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUFFT_INVERSE", ("HIPFFT_BACKWARD", CONV_NUMERIC_LITERAL, API_BLAS)), + ( + "CUFFT_COMPATIBILITY_DEFAULT", + ( + "HIPFFT_COMPATIBILITY_DEFAULT", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ("cuComplex", ("hipComplex", CONV_TYPE, API_BLAS)), + ("cuDoubleComplex", ("hipDoubleComplex", CONV_TYPE, API_BLAS)), + ("cufftResult_t", ("hipfftResult_t", CONV_TYPE, API_FFT)), + ("cufftResult", ("hipfftResult", CONV_TYPE, API_FFT)), + ("CUFFT_SUCCESS", ("HIPFFT_SUCCESS", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_INVALID_PLAN", ("HIPFFT_INVALID_PLAN", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_ALLOC_FAILED", ("HIPFFT_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_INVALID_TYPE", ("HIPFFT_INVALID_TYPE", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "CUFFT_INVALID_VALUE", + ("HIPFFT_INVALID_VALUE", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_INTERNAL_ERROR", + ("HIPFFT_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_FFT), + ), + ("CUFFT_EXEC_FAILED", ("HIPFFT_EXEC_FAILED", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_SETUP_FAILED", ("HIPFFT_SETUP_FAILED", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_INVALID_SIZE", ("HIPFFT_INVALID_SIZE", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "CUFFT_UNALIGNED_DATA", + ("HIPFFT_UNALIGNED_DATA", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_INCOMPLETE_PARAMETER_LIST", + ("HIPFFT_INCOMPLETE_PARAMETER_LIST", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_INVALID_DEVICE", + ("HIPFFT_INVALID_DEVICE", CONV_NUMERIC_LITERAL, API_FFT), + ), + ("CUFFT_PARSE_ERROR", ("HIPFFT_PARSE_ERROR", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_NO_WORKSPACE", ("HIPFFT_NO_WORKSPACE", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "CUFFT_NOT_IMPLEMENTED", + ("HIPFFT_NOT_IMPLEMENTED", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_LICENSE_ERROR", + ("HIPFFT_LICENSE_ERROR", CONV_NUMERIC_LITERAL, API_FFT, HIP_UNSUPPORTED), + ), + ( + "CUFFT_NOT_SUPPORTED", + ("HIPFFT_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_FFT), + ), + ("cufftType_t", ("hipfftType_t", CONV_TYPE, API_FFT)), + ("cufftType", ("hipfftType", CONV_TYPE, API_FFT)), + ("CUFFT_R2C", ("HIPFFT_R2C", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_C2R", ("HIPFFT_C2R", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_C2C", ("HIPFFT_C2C", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_D2Z", ("HIPFFT_D2Z", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_Z2D", ("HIPFFT_Z2D", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_Z2Z", ("HIPFFT_Z2Z", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "cufftCompatibility_t", + ("hipfftCompatibility_t", CONV_TYPE, API_FFT, HIP_UNSUPPORTED), + ), + ( + "cufftCompatibility", + ("hipfftCompatibility", CONV_TYPE, API_FFT, HIP_UNSUPPORTED), + ), + ( + "CUFFT_COMPATIBILITY_FFTW_PADDING", + ( + "HIPFFT_COMPATIBILITY_FFTW_PADDING", + CONV_NUMERIC_LITERAL, + API_FFT, + HIP_UNSUPPORTED, + ), + ), + ("cufftReal", ("hipfftReal", CONV_TYPE, API_FFT)), + ("cufftDoubleReal", ("hipfftDoubleReal", CONV_TYPE, API_FFT)), + ("cufftComplex", ("hipfftComplex", CONV_TYPE, API_FFT)), + ("cufftDoubleComplex", ("hipfftDoubleComplex", CONV_TYPE, API_FFT)), + ("cufftHandle", ("hipfftHandle", CONV_TYPE, API_FFT)), + ("cufftPlan1d", ("hipfftPlan1d", CONV_MATH_FUNC, API_FFT)), + ("cufftPlan2d", ("hipfftPlan2d", CONV_MATH_FUNC, API_FFT)), + ("cufftPlan3d", ("hipfftPlan3d", CONV_MATH_FUNC, API_FFT)), + ("cufftPlanMany", ("hipfftPlanMany", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlan1d", ("hipfftMakePlan1d", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlan2d", ("hipfftMakePlan2d", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlan3d", ("hipfftMakePlan3d", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlanMany", ("hipfftMakePlanMany", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlanMany64", ("hipfftMakePlanMany64", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSizeMany64", ("hipfftGetSizeMany64", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimate1d", ("hipfftEstimate1d", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimate2d", ("hipfftEstimate2d", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimate3d", ("hipfftEstimate3d", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimateMany", ("hipfftEstimateMany", CONV_MATH_FUNC, API_FFT)), + ("cufftCreate", ("hipfftCreate", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize1d", ("hipfftGetSize1d", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize2d", ("hipfftGetSize2d", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize3d", ("hipfftGetSize3d", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSizeMany", ("hipfftGetSizeMany", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize", ("hipfftGetSize", CONV_MATH_FUNC, API_FFT)), + ("cufftSetWorkArea", ("hipfftSetWorkArea", CONV_MATH_FUNC, API_FFT)), + ( + "cufftSetAutoAllocation", + ("hipfftSetAutoAllocation", CONV_MATH_FUNC, API_FFT), + ), + ("cufftXtExec", ("hipfftXtExec", CONV_MATH_FUNC, API_FFT)), + ("cufftXtMakePlanMany", ("hipfftXtMakePlanMany", CONV_MATH_FUNC, API_FFT)), + ("cufftExecC2C", ("hipfftExecC2C", CONV_MATH_FUNC, API_FFT)), + ("cufftExecR2C", ("hipfftExecR2C", CONV_MATH_FUNC, API_FFT)), + ("cufftExecC2R", ("hipfftExecC2R", CONV_MATH_FUNC, API_FFT)), + ("cufftExecZ2Z", ("hipfftExecZ2Z", CONV_MATH_FUNC, API_FFT)), + ("cufftExecD2Z", ("hipfftExecD2Z", CONV_MATH_FUNC, API_FFT)), + ("cufftExecZ2D", ("hipfftExecZ2D", CONV_MATH_FUNC, API_FFT)), + ("cufftSetStream", ("hipfftSetStream", CONV_MATH_FUNC, API_FFT)), + ("cufftDestroy", ("hipfftDestroy", CONV_MATH_FUNC, API_FFT)), + ("cufftGetVersion", ("hipfftGetVersion", CONV_MATH_FUNC, API_FFT)), + ( + "cufftGetProperty", + ("hipfftGetProperty", CONV_MATH_FUNC, API_FFT, HIP_UNSUPPORTED), + ), + ("nvrtcResult", ("hiprtcResult", CONV_TYPE, API_RTC)), + ("NVRTC_SUCCESS", ("HIPRTC_SUCCESS", CONV_TYPE, API_RTC)), + ( + "NVRTC_ERROR_OUT_OF_MEMORY", + ("HIPRTC_ERROR_OUT_OF_MEMORY", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_PROGRAM_CREATION_FAILURE", + ("HIPRTC_ERROR_PROGRAM_CREATION_FAILURE", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_INVALID_INPUT", + ("HIPRTC_ERROR_INVALID_INPUT", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_INVALID_PROGRAM", + ("HIPRTC_ERROR_INVALID_PROGRAM", CONV_TYPE, API_RTC), + ), + ("NVRTC_ERROR_COMPILATION", ("HIPRTC_ERROR_COMPILATION", CONV_TYPE, API_RTC)), + ( + "NVRTC_ERROR_BUILTIN_OPERATION_FAILURE", + ("HIPRTC_ERROR_BUILTIN_OPERATION_FAILURE", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION", + ("HIPRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID", + ("HIPRTC_ERROR_NAME_EXPRESSION_NOT_VALID", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_INTERNAL_ERROR", + ("HIPRTC_ERROR_INTERNAL_ERROR", CONV_TYPE, API_RTC), + ), + ("nvrtcGetErrorString", ("hiprtcGetErrorString", CONV_JIT, API_RTC)), + ("nvrtcVersion", ("hiprtcVersion", CONV_JIT, API_RTC)), + ("nvrtcProgram", ("hiprtcProgram", CONV_TYPE, API_RTC)), + ("nvrtcAddNameExpression", ("hiprtcAddNameExpression", CONV_JIT, API_RTC)), + ("nvrtcCompileProgram", ("hiprtcCompileProgram", CONV_JIT, API_RTC)), + ("nvrtcCreateProgram", ("hiprtcCreateProgram", CONV_JIT, API_RTC)), + ("nvrtcDestroyProgram", ("hiprtcDestroyProgram", CONV_JIT, API_RTC)), + ("nvrtcGetLoweredName", ("hiprtcGetLoweredName", CONV_JIT, API_RTC)), + ("nvrtcGetProgramLog", ("hiprtcGetProgramLog", CONV_JIT, API_RTC)), + ("nvrtcGetProgramLogSize", ("hiprtcGetProgramLogSize", CONV_JIT, API_RTC)), + ("nvrtcGetPTX", ("hiprtcGetCode", CONV_JIT, API_RTC)), + ("nvrtcGetPTXSize", ("hiprtcGetCodeSize", CONV_JIT, API_RTC)), + ("thrust::cuda", ("thrust::hip", CONV_MATH_FUNC, API_BLAS)), + ( + "cudaCpuDeviceId", + ("hipCpuDeviceId", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + # The caffe2 directory does a string match; pytorch does a word-boundary match. + # Patterns such as 'cub::' will not match for pytorch. + # We list all current uses of cub symbols for this reason. + ("cub::", ("hipcub::", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ArgMax", ("hipcub::ArgMax", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ArgMin", ("hipcub::ArgMin", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_SCAN_WARP_SCANS", ("hipcub::BLOCK_SCAN_WARP_SCANS", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_REDUCE_WARP_REDUCTIONS", ("hipcub::BLOCK_REDUCE_WARP_REDUCTIONS", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_STORE_WARP_TRANSPOSE", ("hipcub::BLOCK_STORE_WARP_TRANSPOSE", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_LOAD_DIRECT", ("hipcub::BLOCK_LOAD_DIRECT", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BLOCK_STORE_DIRECT", ("hipcub::BLOCK_STORE_DIRECT", CONV_SPECIAL_FUNC, API_RUNTIME)), + ( + "cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY", + ("hipcub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY", CONV_SPECIAL_FUNC, API_RUNTIME) + ), + ("cub::BlockReduce", ("hipcub::BlockReduce", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockScan", ("hipcub::BlockScan", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockLoad", ("hipcub::BlockLoad", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockStore", ("hipcub::BlockStore", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockRakingLayout", ("hipcub::BlockRakingLayout", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockRadixSort", ("hipcub::BlockRadixSort", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Uninitialized", ("hipcub::Uninitialized", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::RowMajorTid", ("hipcub::RowMajorTid", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::CachingDeviceAllocator", ("hipcub::CachingDeviceAllocator", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::CountingInputIterator", ("hipcub::CountingInputIterator", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceRadixSort", ("hipcub::DeviceRadixSort", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceReduce", ("hipcub::DeviceReduce", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceRunLengthEncode", ("hipcub::DeviceRunLengthEncode", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceScan", ("hipcub::DeviceScan", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceSegmentedRadixSort", ("hipcub::DeviceSegmentedRadixSort", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceSegmentedReduce", ("hipcub::DeviceSegmentedReduce", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::DeviceSelect", ("hipcub::DeviceSelect", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::FpLimits", ("hipcub::FpLimits", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::KeyValuePair", ("hipcub::KeyValuePair", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Max", ("hipcub::Max", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Min", ("hipcub::Min", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Sum", ("hipcub::Sum", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::Log2", ("hipcub::Log2", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::LaneId", ("hipcub::LaneId", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::WarpMask", ("hipcub::WarpMask", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ShuffleIndex", ("hipcub::ShuffleIndex", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ShuffleDown", ("hipcub::ShuffleDown", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::ArgIndexInputIterator", ("hipcub::ArgIndexInputIterator", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::TransformInputIterator", ("hipcub::TransformInputIterator", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::WarpReduce", ("hipcub::WarpReduce", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::CTA_SYNC", ("hipcub::CTA_SYNC", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("nvtxMark", ("roctxMark", CONV_OTHER, API_ROCTX)), + ("nvtxMarkA", ("roctxMarkA", CONV_OTHER, API_ROCTX)), + ("nvtxRangePushA", ("roctxRangePushA", CONV_OTHER, API_ROCTX)), + ("nvtxRangePop", ("roctxRangePop", CONV_OTHER, API_ROCTX)), + ("nvtxRangeStartA", ("roctxRangeStartA", CONV_OTHER, API_ROCTX)), + ("nvtxRangeEnd", ("roctxRangeStop", CONV_OTHER, API_ROCTX)), + ("nvtxRangeId_t", ("int", CONV_OTHER, API_ROCTX)), + ("nvmlReturn_t", ("rsmi_status_t", CONV_OTHER, API_ROCMSMI)), + ("NVML_SUCCESS", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)), + ("NVML_P2P_CAPS_INDEX_READ", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)), + ("NVML_P2P_STATUS_OK", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)), + ("NVML_ERROR_INSUFFICIENT_SIZE", ("RSMI_STATUS_INSUFFICIENT_SIZE", CONV_OTHER, API_ROCMSMI)), + ("nvmlDevice_t", ("uint32_t", CONV_OTHER, API_ROCMSMI)), + ("nvmlGpuP2PStatus_t", ("bool", CONV_OTHER, API_ROCMSMI)), + ("nvmlProcessInfo_t", ("rsmi_process_info_t", CONV_OTHER, API_ROCMSMI)), + ("nvmlGpuP2PCapsIndex_t", ("uint32_t", CONV_OTHER, API_ROCMSMI)), + ] +) + +CUDA_SPECIAL_MAP = collections.OrderedDict( + [ + # SPARSE + ("cusparseStatus_t", ("hipsparseStatus_t", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseHandle_t", ("hipsparseHandle_t", CONV_MATH_FUNC, API_SPECIAL)), + ("cuComplex", ("hipComplex", CONV_TYPE, API_SPECIAL)), + ("cuDoubleComplex", ("hipDoubleComplex", CONV_TYPE, API_SPECIAL)), + ( + "CUSPARSE_POINTER_MODE_HOST", + ("HIPSPARSE_POINTER_MODE_HOST", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ("cusparseOperation_t", ("hipsparseOperation_t", CONV_TYPE, API_SPECIAL)), + ( + "cusparseCreateMatDescr", + ("hipsparseCreateMatDescr", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseCreate", ("hipsparseCreate", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseDestroyMatDescr", + ("hipsparseDestroyMatDescr", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseDestroy", ("hipsparseDestroy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXcoo2csr", ("hipsparseXcoo2csr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseMatDescr_t", ("hipsparseMatDescr_t", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDiagType_t", ("hipsparseDiagType_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_DIAG_TYPE_UNIT", ("HIPSPARSE_DIAG_TYPE_UNIT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_DIAG_TYPE_NON_UNIT", ("HIPSPARSE_DIAG_TYPE_NON_UNIT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseSetMatDiagType", ("hipsparseSetMatDiagType", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseFillMode_t", ("hipsparseFillMode_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_FILL_MODE_UPPER", ("HIPSPARSE_FILL_MODE_UPPER", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_FILL_MODE_LOWER", ("HIPSPARSE_FILL_MODE_LOWER", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseSetMatFillMode", ("hipsparseSetMatFillMode", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDirection_t", ("hipsparseDirection_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_DIRECTION_ROW", ("HIPSPARSE_DIRECTION_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_DIRECTION_COLUMN", ("HIPSPARSE_DIRECTION_COLUMN", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseSolvePolicy_t", ("hipsparseSolvePolicy_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_SOLVE_POLICY_NO_LEVEL", ("HIPSPARSE_SOLVE_POLICY_NO_LEVEL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SOLVE_POLICY_USE_LEVEL", ("HIPSPARSE_SOLVE_POLICY_USE_LEVEL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseCreateBsrsv2Info", ("hipsparseCreateBsrsv2Info", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateBsrsm2Info", ("hipsparseCreateBsrsm2Info", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroyBsrsv2Info", ("hipsparseDestroyBsrsv2Info", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroyBsrsm2Info", ("hipsparseDestroyBsrsm2Info", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrmm", ("hipsparseSbsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrmm", ("hipsparseDbsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrmm", ("hipsparseCbsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrmm", ("hipsparseZbsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrmv", ("hipsparseSbsrmv", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrmv", ("hipsparseDbsrmv", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrmv", ("hipsparseCbsrmv", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrmv", ("hipsparseZbsrmv", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsv2_bufferSize", ("hipsparseSbsrsv2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsv2_bufferSize", ("hipsparseDbsrsv2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsv2_bufferSize", ("hipsparseCbsrsv2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsv2_bufferSize", ("hipsparseZbsrsv2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsv2_analysis", ("hipsparseSbsrsv2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsv2_analysis", ("hipsparseDbsrsv2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsv2_analysis", ("hipsparseCbsrsv2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsv2_analysis", ("hipsparseZbsrsv2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsv2_solve", ("hipsparseSbsrsv2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsv2_solve", ("hipsparseDbsrsv2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsv2_solve", ("hipsparseCbsrsv2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsv2_solve", ("hipsparseZbsrsv2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsm2_bufferSize", ("hipsparseSbsrsm2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsm2_bufferSize", ("hipsparseDbsrsm2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsm2_bufferSize", ("hipsparseCbsrsm2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsm2_bufferSize", ("hipsparseZbsrsm2_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsm2_analysis", ("hipsparseSbsrsm2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsm2_analysis", ("hipsparseDbsrsm2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsm2_analysis", ("hipsparseCbsrsm2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsm2_analysis", ("hipsparseZbsrsm2_analysis", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSbsrsm2_solve", ("hipsparseSbsrsm2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDbsrsm2_solve", ("hipsparseDbsrsm2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCbsrsm2_solve", ("hipsparseCbsrsm2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZbsrsm2_solve", ("hipsparseZbsrsm2_solve", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrmm2", ("hipsparseScsrmm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrmm2", ("hipsparseDcsrmm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCcsrmm2", ("hipsparseCcsrmm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZcsrmm2", ("hipsparseZcsrmm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrmm", ("hipsparseScsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrmm", ("hipsparseDcsrmm", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseXcsrsort_bufferSizeExt", + ("hipsparseXcsrsort_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseCreateCsrgemm2Info", ("hipsparseCreateCsrgemm2Info", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseDestroyCsrgemm2Info", + ("hipsparseDestroyCsrgemm2Info", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseXcsrgemm2Nnz", ("hipsparseXcsrgemm2Nnz", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrgemm2_bufferSizeExt", ("hipsparseDcsrgemm2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrgemm2_bufferSizeExt", ("hipsparseScsrgemm2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrgemm2", ("hipsparseDcsrgemm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrgemm2", ("hipsparseScsrgemm2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSetPointerMode", ("hipsparseSetPointerMode", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXcsrgeam2Nnz", ("hipsparseXcsrgeam2Nnz", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrgeam2_bufferSizeExt", ("hipsparseScsrgeam2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrgeam2_bufferSizeExt", ("hipsparseDcsrgeam2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCcsrgeam2_bufferSizeExt", ("hipsparseCcsrgeam2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZcsrgeam2_bufferSizeExt", ("hipsparseZcsrgeam2_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseScsrgeam2", ("hipsparseScsrgeam2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDcsrgeam2", ("hipsparseDcsrgeam2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCcsrgeam2", ("hipsparseCcsrgeam2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseZcsrgeam2", ("hipsparseZcsrgeam2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXcsrsort", ("hipsparseXcsrsort", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXbsrsm2_zeroPivot", ("hipsparseXbsrsm2_zeroPivot", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseXbsrsv2_zeroPivot", ("hipsparseXbsrsv2_zeroPivot", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseXcoosort_bufferSizeExt", + ("hipsparseXcoosort_bufferSizeExt", CONV_MATH_FUNC, API_SPECIAL), + ), + ( + "cusparseXcoosortByRow", + ("hipsparseXcoosortByRow", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseSetStream", ("hipsparseSetStream", CONV_MATH_FUNC, API_SPECIAL)), + ( + "cusparseCreateIdentityPermutation", + ("hipsparseCreateIdentityPermutation", CONV_MATH_FUNC, API_SPECIAL), + ), + ( + "cusparseSetMatIndexBase", + ("hipsparseSetMatIndexBase", CONV_MATH_FUNC, API_SPECIAL), + ), + ("cusparseSetMatType", ("hipsparseSetMatType", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMV", ("hipsparseSpMV", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMV_bufferSize", ("hipsparseSpMV_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMM", ("hipsparseSpMM", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMM_bufferSize", ("hipsparseSpMM_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateDnMat", ("hipsparseCreateDnMat", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDnMatSetStridedBatch", ("hipsparseDnMatSetStridedBatch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCsrSetStridedBatch", ("hipsparseCsrSetStridedBatch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateDnVec", ("hipsparseCreateDnVec", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateCsr", ("hipsparseCreateCsr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroyDnMat", ("hipsparseDestroyDnMat", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroyDnVec", ("hipsparseDestroyDnVec", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDestroySpMat", ("hipsparseDestroySpMat", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_destroyDescr", ("hipsparseSpGEMM_destroyDescr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateCoo", ("hipsparseCreateCoo", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCreateCsr", ("hipsparseCreateCsr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_createDescr", ("hipsparseSpGEMM_createDescr", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseDnMatSetStridedBatch", ("hipsparseDnMatSetStridedBatch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_copy", ("hipsparseSpGEMM_copy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSDDMM_bufferSize", ("hipsparseSDDMM_bufferSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSDDMM_preprocess", ("hipsparseSDDMM_preprocess", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSDDMM", ("hipsparseSDDMM", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_compute", ("hipsparseSpGEMM_compute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpGEMM_workEstimation", ("hipsparseSpGEMM_workEstimation", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMatGetSize", ("hipsparseSpMatGetSize", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseCsrSetPointers", ("hipsparseCsrSetPointers", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseSpMVAlg_t", ("hipsparseSpMVAlg_t", CONV_TYPE, API_SPECIAL)), + ("cusparseSpMMAlg_t", ("hipsparseSpMMAlg_t", CONV_TYPE, API_SPECIAL)), + ("cusparseIndexType_t", ("hipsparseIndexType_t", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseMatDescr", ("hipsparseMatDescr", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseDnMatDescr", ("hipsparseDnMatDescr", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseDnVecDescr", ("hipsparseDnVecDescr", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseSpMatDescr", ("hipsparseSpMatDescr", CONV_TYPE, API_SPECIAL)), + # Unsupported ("cusparseSpGEMMDescr", ("hipsparseSpGEMMDescr", CONV_TYPE, API_SPECIAL)), + ("cusparseDnMatDescr_t", ("hipsparseDnMatDescr_t", CONV_TYPE, API_SPECIAL)), + ("cusparseDnVecDescr_t", ("hipsparseDnVecDescr_t", CONV_TYPE, API_SPECIAL)), + ("cusparseSpMatDescr_t", ("hipsparseSpMatDescr_t", CONV_TYPE, API_SPECIAL)), + ("cusparseSpGEMMDescr_t", ("hipsparseSpGEMMDescr_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_INDEX_32I", ("HIPSPARSE_INDEX_32I", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_INDEX_64I", ("HIPSPARSE_INDEX_64I", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_ORDER_ROW", ("HIPSPARSE_ORDER_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_MV_ALG_DEFAULT", ("HIPSPARSE_MV_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_MM_ALG_DEFAULT", ("HIPSPARSE_MM_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_COO_ALG1", ("HIPSPARSE_SPMM_COO_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_COO_ALG2", ("HIPSPARSE_SPMM_COO_ALG2", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_CSR_ALG1", ("HIPSPARSE_SPMM_CSR_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_CSR_ALG2", ("HIPSPARSE_SPMM_CSR_ALG2", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_CSR_ALG3", ("HIPSPARSE_SPMM_CSR_ALG3", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COOMV_ALG", ("HIPSPARSE_COOMV_ALG", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPMM_CSR_ALG1", ("HIPSPARSE_CSRMM_ALG1", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SPGEMM_DEFAULT", ("HIPSPARSE_SPGEMM_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_SDDMM_ALG_DEFAULT", ("HIPSPARSE_SDDMM_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ( + "CUSPARSE_STATUS_SUCCESS", + ("HIPSPARSE_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_NOT_INITIALIZED", + ("HIPSPARSE_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_ALLOC_FAILED", + ("HIPSPARSE_STATUS_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_INVALID_VALUE", + ("HIPSPARSE_STATUS_INVALID_VALUE", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_MAPPING_ERROR", + ("HIPSPARSE_STATUS_MAPPING_ERROR", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_EXECUTION_FAILED", + ("HIPSPARSE_STATUS_EXECUTION_FAILED", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_INTERNAL_ERROR", + ("HIPSPARSE_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED", + ( + "HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED", + CONV_NUMERIC_LITERAL, + API_SPECIAL, + ), + ), + ( + "CUSPARSE_STATUS_ARCH_MISMATCH", + ("HIPSPARSE_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_STATUS_ZERO_PIVOT", + ("HIPSPARSE_STATUS_ZERO_PIVOT", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_OPERATION_TRANSPOSE", + ("HIPSPARSE_OPERATION_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_OPERATION_NON_TRANSPOSE", + ("HIPSPARSE_OPERATION_NON_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE", + ( + "HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE", + CONV_NUMERIC_LITERAL, + API_SPECIAL, + ), + ), + ( + "CUSPARSE_INDEX_BASE_ZERO", + ("HIPSPARSE_INDEX_BASE_ZERO", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_INDEX_BASE_ONE", + ("HIPSPARSE_INDEX_BASE_ONE", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUSPARSE_MATRIX_TYPE_GENERAL", + ("HIPSPARSE_MATRIX_TYPE_GENERAL", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + # SparseLt + ("cuSPARSELt", ("hipSPARSELt", CONV_TYPE, API_SPECIAL)), + ("AT_CUSPARSELT_ENABLED", ("AT_HIPSPARSELT_ENABLED", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_ORDER_ROW", ("HIPSPARSE_ORDER_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_SPARSITY_50_PERCENT", ("HIPSPARSELT_SPARSITY_50_PERCENT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseComputeType", ("hipsparseLtComputetype_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_COMPUTE_32F", ("HIPSPARSELT_COMPUTE_32F", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_16F", ("HIPSPARSELT_COMPUTE_16F", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_32I", ("HIPSPARSELT_COMPUTE_32I", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_TF32", ("HIPSPARSELT_COMPUTE_TF32", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_BIAS_POINTER", ("HIPSPARSELT_MATMUL_BIAS_POINTER", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALG_DEFAULT", ("HIPSPARSELT_MATMUL_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALG_CONFIG_ID", ("HIPSPARSELT_MATMUL_ALG_CONFIG_ID", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", ("HIPSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseLtHandle_t", ("hipsparseLtHandle_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatDescriptor_t", ("hipsparseLtMatDescriptor_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtInit", ("hipsparseLtInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtSpMMACompressedSize2", ("hipsparseLtSpMMACompressedSize2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtSpMMACompress2", ("hipsparseLtSpMMACompress2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescriptor_t", ("hipsparseLtMatmulDescriptor_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatmulPlan_t", ("hipsparseLtMatmulPlan_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatmulAlgSelection_t", ("hipsparseLtMatmulAlgSelection_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtDenseDescriptorInit", ("hipsparseLtDenseDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescriptorInit", ("hipsparseLtMatmulDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescSetAttribute", ("hipsparseLtMatmulDescSetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgSelectionInit", ("hipsparseLtMatmulAlgSelectionInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgSetAttribute", ("hipsparseLtMatmulAlgSetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulPlanInit", ("hipsparseLtMatmulPlanInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulGetWorkspace", ("hipsparseLtMatmulGetWorkspace", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulSearch", ("hipsparseLtMatmulSearch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgGetAttribute", ("hipsparseLtMatmulAlgGetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmul", ("hipsparseLtMatmul", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatDescriptorDestroy", ("hipsparseLtMatDescriptorDestroy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulPlanDestroy", ("hipsparseLtMatmulPlanDestroy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseGetErrorString", ("hipsparseGetErrorString", CONV_MATH_FUNC, API_SPECIAL)), + # SOLVER + ("cublasOperation_t", ("hipsolverOperation_t", CONV_TYPE, API_SPECIAL)), + ("CUBLAS_OP_N", ("HIPSOLVER_OP_N", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ( + "CUBLAS_OP_T", + ("HIPSOLVER_OP_T", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUBLAS_OP_C", + ("HIPSOLVER_OP_C", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ("cublasFillMode_t", ("hipsolverFillMode_t", CONV_TYPE, API_SPECIAL)), + ( + "CUBLAS_FILL_MODE_LOWER", + ("HIPSOLVER_FILL_MODE_LOWER", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ( + "CUBLAS_FILL_MODE_UPPER", + ("HIPSOLVER_FILL_MODE_UPPER", CONV_NUMERIC_LITERAL, API_SPECIAL), + ), + ("cublasSideMode_t", ("hipsolverSideMode_t", CONV_TYPE, API_SPECIAL)), + ("CUBLAS_SIDE_LEFT", ("HIPSOLVER_SIDE_LEFT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUBLAS_SIDE_RIGHT", ("HIPSOLVER_SIDE_RIGHT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + + ("cusolverEigMode_t", ("hipsolverEigMode_t", CONV_TYPE, API_SPECIAL)), + ("CUSOLVER_EIG_MODE_VECTOR", ("HIPSOLVER_EIG_MODE_VECTOR", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSOLVER_EIG_MODE_NOVECTOR", ("HIPSOLVER_EIG_MODE_NOVECTOR", CONV_NUMERIC_LITERAL, API_SPECIAL)), + + ("syevjInfo_t", ("hipsolverSyevjInfo_t", CONV_TYPE, API_SPECIAL)), + ("cusolverDnCreateSyevjInfo", ("hipsolverDnCreateSyevjInfo", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnXsyevjSetSortEig", ("hipsolverDnXsyevjSetSortEig", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnDestroySyevjInfo", ("hipsolverDnDestroySyevjInfo", CONV_MATH_FUNC, API_SPECIAL)), + + ("gesvdjInfo_t", ("hipsolverGesvdjInfo_t", CONV_TYPE, API_SPECIAL)), + ("cusolverDnCreateGesvdjInfo", ("hipsolverDnCreateGesvdjInfo", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnXgesvdjSetSortEig", ("hipsolverDnXgesvdjSetSortEig", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnDestroyGesvdjInfo", ("hipsolverDnDestroyGesvdjInfo", CONV_MATH_FUNC, API_SPECIAL)), + + ("cusolverDnHandle_t", ("hipsolverDnHandle_t", CONV_TYPE, API_SPECIAL)), + ("cusolverDnCreate", ("hipsolverDnCreate", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnSetStream", ("hipsolverDnSetStream", CONV_MATH_FUNC, API_SPECIAL)), + ("cusolverDnDestroy", ("hipsolverDnDestroy", CONV_MATH_FUNC, API_SPECIAL)), + + # from aten/src/ATen/native/hip/linalg/HIPSolver.cpp + ('cusolverDnParams_t', ('hipsolverDnParams_t', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgeqrf', ('hipsolverDnCgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgeqrf_bufferSize', ('hipsolverDnCgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvd', ('hipsolverDnCgesvd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvd_bufferSize', ('hipsolverDnCgesvd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdj', ('hipsolverDnCgesvdj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdjBatched', ('hipsolverDnCgesvdjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdjBatched_bufferSize', ('hipsolverDnCgesvdjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdj_bufferSize', ('hipsolverDnCgesvdj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgetrf', ('hipsolverDnCgetrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgetrf_bufferSize', ('hipsolverDnCgetrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgetrs', ('hipsolverDnCgetrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevd', ('hipsolverDnCheevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevd_bufferSize', ('hipsolverDnCheevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevj', ('hipsolverDnCheevj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevjBatched', ('hipsolverDnCheevjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevjBatched_bufferSize', ('hipsolverDnCheevjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCheevj_bufferSize', ('hipsolverDnCheevj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrf', ('hipsolverDnCpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrfBatched', ('hipsolverDnCpotrfBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrf_bufferSize', ('hipsolverDnCpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrs', ('hipsolverDnCpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCpotrsBatched', ('hipsolverDnCpotrsBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCungqr', ('hipsolverDnCungqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCungqr_bufferSize', ('hipsolverDnCungqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCunmqr', ('hipsolverDnCunmqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCunmqr_bufferSize', ('hipsolverDnCunmqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgeqrf', ('hipsolverDnDgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgeqrf_bufferSize', ('hipsolverDnDgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvd', ('hipsolverDnDgesvd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvd_bufferSize', ('hipsolverDnDgesvd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdj', ('hipsolverDnDgesvdj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdjBatched', ('hipsolverDnDgesvdjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdjBatched_bufferSize', ('hipsolverDnDgesvdjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdj_bufferSize', ('hipsolverDnDgesvdj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgetrf', ('hipsolverDnDgetrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgetrf_bufferSize', ('hipsolverDnDgetrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgetrs', ('hipsolverDnDgetrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDorgqr', ('hipsolverDnDorgqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDorgqr_bufferSize', ('hipsolverDnDorgqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDormqr', ('hipsolverDnDormqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDormqr_bufferSize', ('hipsolverDnDormqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrf', ('hipsolverDnDpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrfBatched', ('hipsolverDnDpotrfBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrf_bufferSize', ('hipsolverDnDpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrs', ('hipsolverDnDpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDpotrsBatched', ('hipsolverDnDpotrsBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevd', ('hipsolverDnDsyevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevd_bufferSize', ('hipsolverDnDsyevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevj', ('hipsolverDnDsyevj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevjBatched', ('hipsolverDnDsyevjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevjBatched_bufferSize', ('hipsolverDnDsyevjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsyevj_bufferSize', ('hipsolverDnDsyevj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgeqrf', ('hipsolverDnSgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgeqrf_bufferSize', ('hipsolverDnSgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvd', ('hipsolverDnSgesvd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvd_bufferSize', ('hipsolverDnSgesvd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvdj', ('hipsolverDnSgesvdj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvdjBatched', ('hipsolverDnSgesvdjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvdjBatched_bufferSize', ('hipsolverDnSgesvdjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgesvdj_bufferSize', ('hipsolverDnSgesvdj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgetrf', ('hipsolverDnSgetrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgetrf_bufferSize', ('hipsolverDnSgetrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSgetrs', ('hipsolverDnSgetrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSorgqr', ('hipsolverDnSorgqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSorgqr_bufferSize', ('hipsolverDnSorgqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSormqr', ('hipsolverDnSormqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSormqr_bufferSize', ('hipsolverDnSormqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrf', ('hipsolverDnSpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrfBatched', ('hipsolverDnSpotrfBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrf_bufferSize', ('hipsolverDnSpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrs', ('hipsolverDnSpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSpotrsBatched', ('hipsolverDnSpotrsBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevd', ('hipsolverDnSsyevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevd_bufferSize', ('hipsolverDnSsyevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevj', ('hipsolverDnSsyevj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevjBatched', ('hipsolverDnSsyevjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevjBatched_bufferSize', ('hipsolverDnSsyevjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsyevj_bufferSize', ('hipsolverDnSsyevj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXgeqrf', ('hipsolverDnXgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXgeqrf_bufferSize', ('hipsolverDnXgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXpotrf', ('hipsolverDnXpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXpotrf_bufferSize', ('hipsolverDnXpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXpotrs', ('hipsolverDnXpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXsyevd', ('hipsolverDnXsyevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXsyevd_bufferSize', ('hipsolverDnXsyevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgeqrf', ('hipsolverDnZgeqrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgeqrf_bufferSize', ('hipsolverDnZgeqrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvd', ('hipsolverDnZgesvd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvd_bufferSize', ('hipsolverDnZgesvd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdj', ('hipsolverDnZgesvdj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdjBatched', ('hipsolverDnZgesvdjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdjBatched_bufferSize', ('hipsolverDnZgesvdjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdj_bufferSize', ('hipsolverDnZgesvdj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgetrf', ('hipsolverDnZgetrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgetrf_bufferSize', ('hipsolverDnZgetrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgetrs', ('hipsolverDnZgetrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevd', ('hipsolverDnZheevd', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevd_bufferSize', ('hipsolverDnZheevd_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevj', ('hipsolverDnZheevj', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevjBatched', ('hipsolverDnZheevjBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevjBatched_bufferSize', ('hipsolverDnZheevjBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZheevj_bufferSize', ('hipsolverDnZheevj_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrf', ('hipsolverDnZpotrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrfBatched', ('hipsolverDnZpotrfBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrf_bufferSize', ('hipsolverDnZpotrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrs', ('hipsolverDnZpotrs', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZpotrsBatched', ('hipsolverDnZpotrsBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZungqr', ('hipsolverDnZungqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZungqr_bufferSize', ('hipsolverDnZungqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZunmqr', ('hipsolverDnZunmqr', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZunmqr_bufferSize', ('hipsolverDnZunmqr_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + + # sytrf + ('cusolverDnDsytrf_bufferSize', ('hipsolverDnDsytrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsytrf_bufferSize', ('hipsolverDnSsytrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZsytrf_bufferSize', ('hipsolverDnZsytrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCsytrf_bufferSize', ('hipsolverDnCsytrf_bufferSize', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDsytrf', ('hipsolverDnDsytrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnSsytrf', ('hipsolverDnSsytrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZsytrf', ('hipsolverDnZsytrf', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCsytrf', ('hipsolverDnCsytrf', CONV_MATH_FUNC, API_SPECIAL)), + + # gesdva strided + ( + 'cusolverDnSgesvdaStridedBatched_bufferSize', + ('hipsolverDnSgesvdaStridedBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL) + ), + ( + 'cusolverDnDgesvdaStridedBatched_bufferSize', + ('hipsolverDnDgesvdaStridedBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL) + ), + ( + 'cusolverDnCgesvdaStridedBatched_bufferSize', + ('hipsolverDnCgesvdaStridedBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL) + ), + ( + 'cusolverDnZgesvdaStridedBatched_bufferSize', + ('hipsolverDnZgesvdaStridedBatched_bufferSize', CONV_MATH_FUNC, API_SPECIAL) + ), + ('cusolverDnSgesvdaStridedBatched', ('hipsolverDnSgesvdaStridedBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnDgesvdaStridedBatched', ('hipsolverDnDgesvdaStridedBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnCgesvdaStridedBatched', ('hipsolverDnCgesvdaStridedBatched', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnZgesvdaStridedBatched', ('hipsolverDnZgesvdaStridedBatched', CONV_MATH_FUNC, API_SPECIAL)), + + # gesvdj SetXXX + ('cusolverDnXgesvdjSetTolerance', ('hipsolverDnXgesvdjSetTolerance', CONV_MATH_FUNC, API_SPECIAL)), + ('cusolverDnXgesvdjSetMaxSweeps', ('hipsolverDnXgesvdjSetMaxSweeps', CONV_MATH_FUNC, API_SPECIAL)), + ] +) + +PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict( + [ + ("USE_CUDA", ("USE_ROCM", API_PYTORCH)), + ("TORCH_CUDA_CPP_API", ("TORCH_HIP_CPP_API", API_PYTORCH)), + ("TORCH_CUDA_CU_API", ("TORCH_HIP_API", API_PYTORCH)), + ("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)), + ("cudaHostAllocator", ("hipHostAllocator", API_PYTORCH)), + ("cudaDeviceAllocator", ("hipDeviceAllocator", API_PYTORCH)), + ("define MAX_NUM_BLOCKS 200", ("define MAX_NUM_BLOCKS 64", API_PYTORCH)), + ("cuda::CUDAGuard", ("hip::HIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ("CUDAGuard", ("HIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::OptionalCUDAGuard", + ("hip::OptionalHIPGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ("OptionalCUDAGuard", ("OptionalHIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::CUDAStreamGuard", + ("hip::HIPStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ("CUDAStreamGuard", ("HIPStreamGuardMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::OptionalCUDAStreamGuard", + ("hip::OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "OptionalCUDAStreamGuard", + ("OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::CUDAMultiStreamGuard", + ("hip::HIPMultiStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "CUDAMultiStreamGuard", + ("HIPMultiStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + # Only get needs to be transformed this way; all the other ones can go + # straight to the normal versions hip::HIPCachingAllocator + ( + "cuda::CUDACachingAllocator::get", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH), + ), + ( + "CUDACachingAllocator::get", + ("HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::recordStream", + ( + "hip::HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + API_PYTORCH, + ), + ), + ( + "CUDACachingAllocator::recordStream", + ( + "HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + API_PYTORCH, + ), + ), + ( + "cuda::CUDAAllocator::recordStream", + ( + "hip::HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + API_PYTORCH, + ), + ), + ( + "CUDAAllocator::recordStream", + ( + "HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + API_PYTORCH, + ), + ), + ("cuda::CUDAStream", ("hip::HIPStreamMasqueradingAsCUDA", API_PYTORCH)), + ("CUDAStream", ("HIPStreamMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::getStreamFromPool", + ("hip::getStreamFromPoolMasqueradingAsCUDA", API_PYTORCH), + ), + ("getStreamFromPool", ("getStreamFromPoolMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::getDefaultCUDAStream", + ("hip::getDefaultHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::getStreamFromExternal", + ("hip::getStreamFromExternalMasqueradingAsCUDA", API_PYTORCH), + ), + ("getStreamFromExternal", ("getStreamFromExternalMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::getDefaultCUDAStream", + ("hip::getDefaultHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "getDefaultCUDAStream", + ("getDefaultHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::getCurrentCUDAStream", + ("hip::getCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "getCurrentCUDAStream", + ("getCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::setCurrentCUDAStream", + ("hip::setCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "setCurrentCUDAStream", + ("setCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "ATen/cudnn/Handle.h", + ("ATen/miopen/Handle.h", API_PYTORCH), + ), + # TODO: Undo this special-case; see the header for motivation behind this + # hack. It's VERY important this is only applied to PyTorch HIPify. + ( + "c10/cuda/CUDAGuard.h", + ("ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h", API_PYTORCH), + ), + ( + "c10/cuda/CUDACachingAllocator.h", + ("ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h", API_PYTORCH), + ), + ( + "c10/cuda/CUDAStream.h", + ("ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h", API_PYTORCH), + ), + ("gloo/cuda.h", ("gloo/hip.h", API_PYTORCH)), + ( + "gloo/cuda_allreduce_halving_doubling.h", + ("gloo/hip_allreduce_halving_doubling.h", API_PYTORCH), + ), + ( + "gloo/cuda_allreduce_halving_doubling_pipelined.h", + ("gloo/hip_allreduce_halving_doubling_pipelined.h", API_PYTORCH), + ), + ("gloo/cuda_allreduce_ring.h", ("gloo/hip_allreduce_ring.h", API_PYTORCH)), + ("gloo/cuda_allreduce_ring_chunked.h", ("gloo/hip_allreduce_ring_chunked.h", API_PYTORCH)), + ( + "gloo/cuda_broadcast_one_to_all.h", + ("gloo/hip_broadcast_one_to_all.h", API_PYTORCH), + ), + ( + "gloo::CudaAllreduceHalvingDoublingPipelined", + ("gloo::HipAllreduceHalvingDoublingPipelined", API_PYTORCH), + ), + ( + "gloo::CudaAllreduceRingChunked", + ("gloo::HipAllreduceRingChunked", API_PYTORCH), + ), + ("gloo::CudaBroadcastOneToAll", ("gloo::HipBroadcastOneToAll", API_PYTORCH)), + ("gloo::CudaHostWorkspace", ("gloo::HipHostWorkspace", API_PYTORCH)), + ("gloo::CudaDeviceWorkspace", ("gloo::HipDeviceWorkspace", API_PYTORCH)), + ("CUDNN_RNN_RELU", ("miopenRNNRELU", API_PYTORCH)), + ("CUDNN_RNN_TANH", ("miopenRNNTANH", API_PYTORCH)), + ("CUDNN_LSTM", ("miopenLSTM", API_PYTORCH)), + ("CUDNN_GRU", ("miopenGRU", API_PYTORCH)), + ("cudnnRNNMode_t", ("miopenRNNMode_t", API_PYTORCH)), + ("magma_queue_create_from_cuda", ("magma_queue_create_from_hip", API_PYTORCH)), + ] +) + +CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict( + [ + ("PYTORCH_NO_CUDA_MEMORY_CACHING", ("PYTORCH_NO_CUDA_MEMORY_CACHING", API_CAFFE2)), + ("PYTORCH_CUDA_ALLOC_CONF", ("PYTORCH_CUDA_ALLOC_CONF", API_CAFFE2)), + ("cuda_stream", ("hip_stream", API_CAFFE2)), + # if the header is a native hip folder (under hip directory), + # there is no need to add a hip path to it; the trie in hipify script + # takes this mapping order to forbid further replacement + ("/hip/", ("/hip/", API_CAFFE2)), + ("/context_gpu", ("/hip/context_gpu", API_CAFFE2)), + ("/common_gpu", ("/hip/common_gpu", API_CAFFE2)), + ("/cuda_nccl_gpu", ("/hip/hip_nccl_gpu", API_CAFFE2)), + ("/mixed_utils", ("/hip/mixed_utils", API_CAFFE2)), + ("/operator_fallback_gpu", ("/hip/operator_fallback_gpu", API_CAFFE2)), + ( + "/spatial_batch_norm_op_impl", + ("/hip/spatial_batch_norm_op_impl", API_CAFFE2), + ), + ( + "/recurrent_network_executor_gpu", + ("/hip/recurrent_network_executor_gpu", API_CAFFE2), + ), + ( + "/generate_proposals_op_util_nms_gpu", + ("/hip/generate_proposals_op_util_nms_gpu", API_CAFFE2), + ), + ("/max_pool_with_index_gpu", ("/hip/max_pool_with_index_gpu", API_CAFFE2)), + ("/THCCachingAllocator_gpu", ("/hip/THCCachingAllocator_gpu", API_CAFFE2)), + ("/top_k_heap_selection", ("/hip/top_k_heap_selection", API_CAFFE2)), + ("/top_k_radix_selection", ("/hip/top_k_radix_selection", API_CAFFE2)), + ("/GpuAtomics", ("/hip/GpuAtomics", API_CAFFE2)), + ("/GpuDefs", ("/hip/GpuDefs", API_CAFFE2)), + ("/GpuScanUtils", ("/hip/GpuScanUtils", API_CAFFE2)), + ("/GpuBitonicSort", ("/hip/GpuBitonicSort", API_CAFFE2)), + ("/math/reduce.cuh", ("/math/hip/reduce.cuh", API_CAFFE2)), + ("/sgd/adagrad_fused_op_gpu.cuh", ("/sgd/hip/adagrad_fused_op_gpu.cuh", API_CAFFE2)), + ("/operators/segment_reduction_op_gpu.cuh", ("/operators/hip/segment_reduction_op_gpu.cuh", API_CAFFE2)), + ("/gather_op.cuh", ("/hip/gather_op.cuh", API_CAFFE2)), + ("caffe2/core/common_cudnn.h", ("caffe2/core/hip/common_miopen.h", API_CAFFE2)), + ("REGISTER_CUDA_OPERATOR", ("REGISTER_HIP_OPERATOR", API_CAFFE2)), + ("CUDA_1D_KERNEL_LOOP", ("HIP_1D_KERNEL_LOOP", API_CAFFE2)), + ("CUDAContext", ("HIPContext", API_CAFFE2)), + ("CAFFE_CUDA_NUM_THREADS", ("CAFFE_HIP_NUM_THREADS", API_CAFFE2)), + ("HasCudaGPU", ("HasHipGPU", API_CAFFE2)), + ("__expf", ("expf", API_CAFFE2)), + ("CUBLAS_ENFORCE", ("HIPBLAS_ENFORCE", API_CAFFE2)), + ("CUBLAS_CHECK", ("HIPBLAS_CHECK", API_CAFFE2)), + ("cublas_handle", ("hipblas_handle", API_CAFFE2)), + ("CURAND_ENFORCE", ("HIPRAND_ENFORCE", API_CAFFE2)), + ("CURAND_CHECK", ("HIPRAND_CHECK", API_CAFFE2)), + ("curandGenerateUniform", ("hiprandGenerateUniform", API_CAFFE2)), + ("curand_generator", ("hiprand_generator", API_CAFFE2)), + ("CaffeCudaGetDevice", ("CaffeHipGetDevice", API_CAFFE2)), + # do not rename CUDA_KERNEL_ASSERT, lazyInitCUDA in caffe2 sources + # the ordered dict guarantees this pattern will match first, before "CUDA" + ("CUDA_KERNEL_ASSERT", ("CUDA_KERNEL_ASSERT", API_CAFFE2)), + ("lazyInitCUDA", ("lazyInitCUDA", API_CAFFE2)), + ("CUDA_VERSION", ("TORCH_HIP_VERSION", API_CAFFE2)), + ("CUDA", ("HIP", API_CAFFE2)), + ("Cuda", ("Hip", API_CAFFE2)), + ("cuda_", ("hip_", API_CAFFE2)), + ("_cuda", ("_hip", API_CAFFE2)), + ("CUDNN", ("MIOPEN", API_CAFFE2)), + ("CuDNN", ("MIOPEN", API_CAFFE2)), + ("cudnn", ("miopen", API_CAFFE2)), + ("namespace cuda", ("namespace hip", API_CAFFE2)), + ("cuda::CUDAGuard", ("hip::HIPGuard", API_CAFFE2)), + ("cuda::OptionalCUDAGuard", ("hip::OptionalHIPGuard", API_CAFFE2)), + ("cuda::CUDAStreamGuard", ("hip::HIPStreamGuard", API_CAFFE2)), + ("cuda::OptionalCUDAStreamGuard", ("hip::OptionalHIPStreamGuard", API_CAFFE2)), + ("c10/cuda/CUDAGuard.h", ("c10/hip/HIPGuard.h", API_CAFFE2)), + ("gloo/cuda", ("gloo/hip", API_CAFFE2)), + ] +) + +# We must tread very carefully here. Blanket conversions like are done +# in CAFFE2_SPECIFIC_MAPPINGS are not presently supported on PyTorch, +# because a regex for CUDA will also match a filename like CUDAGuard.h, +# but the HIPIFY script doesn't presently move the file and so the substitution +# will be invalid. Instead, we specifically list out every identifier +# and file from c10/cuda which may be used externally, and do substitutions this +# way. +# +# NB: if you want a transformation to ONLY apply to the c10/ directory, +# put it as API_CAFFE2 +C10_MAPPINGS = collections.OrderedDict( + [ + ("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)), + ("CUDA_LAUNCH_BLOCKING=1", ("AMD_SERIALIZE_KERNEL=3", API_C10)), + ("CUDA_LAUNCH_BLOCKING", ("AMD_SERIALIZE_KERNEL", API_C10)), + ("cuda::compat::", ("hip::compat::", API_C10)), + ("c10/cuda/CUDAAlgorithm.h", ("c10/hip/HIPAlgorithm.h", API_C10)), + ("c10/cuda/CUDADeviceAssertion.h", ("c10/hip/HIPDeviceAssertion.h", API_C10)), + ("c10/cuda/CUDADeviceAssertionHost.h", ("c10/hip/HIPDeviceAssertionHost.h", API_C10)), + ("c10/cuda/CUDAException.h", ("c10/hip/HIPException.h", API_C10)), + ("c10/cuda/CUDAMacros.h", ("c10/hip/HIPMacros.h", API_C10)), + ("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)), + ("c10/cuda/CUDAFunctions.h", ("c10/hip/HIPFunctions.h", API_C10)), + ("c10/cuda/CUDAMiscFunctions.h", ("c10/hip/HIPMiscFunctions.h", API_C10)), + ("c10/cuda/CUDAStream.h", ("c10/hip/HIPStream.h", API_C10)), + ("c10/cuda/CUDAGraphsC10Utils.h", ("c10/hip/HIPGraphsC10Utils.h", API_C10)), + ("c10/cuda/CUDAAllocatorConfig.h", ("c10/hip/HIPAllocatorConfig.h", API_C10)), + ("c10/cuda/CUDACachingAllocator.h", ("c10/hip/HIPCachingAllocator.h", API_C10)), + ("c10/cuda/impl/CUDATest.h", ("c10/hip/impl/HIPTest.h", API_C10)), + ("c10/cuda/impl/CUDAGuardImpl.h", ("c10/hip/impl/HIPGuardImpl.h", API_C10)), + ( + "c10/cuda/impl/cuda_cmake_macros.h", + ("c10/hip/impl/hip_cmake_macros.h", API_C10), + ), + ("C10_CUDA_CHECK", ("C10_HIP_CHECK", API_C10)), + ("C10_CUDA_CHECK_WARN", ("C10_HIP_CHECK_WARN", API_C10)), + ("C10_CUDA_ERROR_HANDLED", ("C10_HIP_ERROR_HANDLED", API_C10)), + ("C10_CUDA_IGNORE_ERROR", ("C10_HIP_IGNORE_ERROR", API_C10)), + ("C10_CUDA_CLEAR_ERROR", ("C10_HIP_CLEAR_ERROR", API_C10)), + ("c10::cuda", ("c10::hip", API_C10)), + ("cuda::CUDAStream", ("hip::HIPStream", API_C10)), + ("CUDAStream", ("HIPStream", API_C10)), + # This substitution is not permissible, because there's another copy of this + # function in torch/cuda.h + # ("cuda::device_count", ("hip::device_count", API_C10)), + ("cuda::current_device", ("hip::current_device", API_C10)), + ("cuda::set_device", ("hip::set_device", API_C10)), + ("cuda::device_synchronize", ("hip::device_synchronize", API_C10)), + ("cuda::getStreamFromPool", ("hip::getStreamFromPool", API_C10)), + ("getStreamFromPool", ("getStreamFromPool", API_C10)), + ("cuda::getDefaultCUDAStream", ("hip::getDefaultHIPStream", API_C10)), + ("getDefaultCUDAStream", ("getDefaultHIPStream", API_C10)), + ("cuda::getCurrentCUDAStream", ("hip::getCurrentHIPStream", API_C10)), + ("getCurrentCUDAStream", ("getCurrentHIPStream", API_C10)), + ("cuda::get_cuda_check_prefix", ("hip::get_cuda_check_prefix", API_C10)), + ("cuda::setCurrentCUDAStream", ("hip::setCurrentHIPStream", API_C10)), + ("setCurrentCUDAStream", ("setCurrentHIPStream", API_C10)), + ("cuda::CUDACachingAllocator", ("hip::HIPCachingAllocator", API_C10)), + ("CUDACachingAllocator", ("HIPCachingAllocator", API_C10)), + ("cuda::CUDAAllocatorConfig", ("hip::HIPAllocatorConfig", API_C10)), + ("CUDAAllocatorConfig", ("HIPAllocatorConfig", API_C10)), + ("pinned_use_cuda_host_register", ("pinned_use_hip_host_register", API_C10)), + ("c10::cuda::CUDAAllocator", ("c10::hip::HIPAllocator", API_C10)), + ("cuda::CUDAAllocator", ("hip::HIPAllocator", API_C10)), + ("CUDAStreamCaptureModeGuard", ("HIPStreamCaptureModeGuard", API_C10)), + ("cuda::CUDAStreamCaptureModeGuard", ("cuda::HIPStreamCaptureModeGuard", API_C10)), + ("CUDAAllocator", ("HIPAllocator", API_C10)), + ("C10_CUDA_KERNEL_LAUNCH_CHECK", ("C10_HIP_KERNEL_LAUNCH_CHECK", API_C10)), + ("CUDAKernelLaunchRegistry", ("HIPKernelLaunchRegistry", API_C10)), + ("c10::cuda::get_cuda_check_suffix", ("c10::hip::get_hip_check_suffix", API_C10)), + ] +) + +# NB: C10 mappings are more specific than Caffe2 mappings, so run them +# first +CUDA_TO_HIP_MAPPINGS = [ + CUDA_IDENTIFIER_MAP, + CUDA_TYPE_NAME_MAP, + CUDA_INCLUDE_MAP, + CUDA_SPECIAL_MAP, + C10_MAPPINGS, + PYTORCH_SPECIFIC_MAPPINGS, + CAFFE2_SPECIFIC_MAPPINGS, +] diff --git a/.venv/lib/python3.12/site-packages/torch/utils/hipify/hipify_python.py b/.venv/lib/python3.12/site-packages/torch/utils/hipify/hipify_python.py new file mode 100644 index 0000000000000000000000000000000000000000..76b63e70a8ef0dee5747ffcb65cac4771da08312 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/hipify/hipify_python.py @@ -0,0 +1,1176 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs +""" The Python Hipify script. +## +# Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved. +# 2017-2018 Advanced Micro Devices, Inc. and +# Facebook Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +import argparse +import fnmatch +import re +import shutil +import sys +import os + +from . import constants +from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS +from .cuda_to_hip_mappings import MATH_TRANSPILATIONS + +from typing import Optional +from collections.abc import Iterator +from collections.abc import Mapping, Iterable +from enum import Enum +import functools +import hashlib + +class CurrentState(Enum): + INITIALIZED = 1 + DONE = 2 + +class HipifyResult: + def __init__(self, current_state, hipified_path): + self.current_state = current_state + self.hipified_path = hipified_path + self.status = "" + + def __str__(self): + return (f"HipifyResult:: current_state: {self.current_state}, hipified_path : {self.hipified_path}, status: {self.status}") + +HipifyFinalResult = dict[str, HipifyResult] +HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n" +HIPIFY_FINAL_RESULT: HipifyFinalResult = {} + +# Hardcode the PyTorch template map +"""This dictionary provides the mapping from PyTorch kernel template types +to their actual types.""" +PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"} + +__all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter', + 'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group', + 'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared', + 'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_special_file', 'is_caffe2_gpu_file', + 'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header', + 'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'CurrentState', 'HipifyResult', 'hipify'] + + +class InputError(Exception): + # Exception raised for errors in the input. + + def __init__(self, message): + super().__init__(message) + self.message = message + + def __str__(self): + return f"Input error: {self.message}" + + +def openf(filename, mode): + return open(filename, mode, errors='ignore') + + +# Color coding for printing +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + + +# To the programmer, the output of hipify most likely are intermediates. +# This class allows users of hipify to ask for a cleanup by running the +# hipify and compilation in a with instantiating this context manager class +# with keep_intermediates=False. +# The main usecase is the cpp_extensions, specifically the load method. +# It is a good idea to keep intermediates (in case of errors or to +# not recompile unchanged files), but in cases where you don't want to +# keep them (e.g. in the CI), this can be used to remove files. +class GeneratedFileCleaner: + """Context Manager to clean up generated files""" + def __init__(self, keep_intermediates=False): + self.keep_intermediates = keep_intermediates + self.files_to_clean = set() + self.dirs_to_clean = [] + + def __enter__(self): + return self + + def open(self, fn, *args, **kwargs): + if not os.path.exists(fn): + self.files_to_clean.add(os.path.abspath(fn)) + return open(fn, *args, **kwargs) + + def makedirs(self, dn, exist_ok=False): + parent, n = os.path.split(dn) + if not n: + parent, n = os.path.split(parent) + if parent and n and not os.path.exists(parent): + self.makedirs(parent, exist_ok=True) + if not os.path.isdir(dn) or not exist_ok: + os.mkdir(dn) + self.dirs_to_clean.append(os.path.abspath(dn)) + + def __exit__(self, type, value, traceback): + if not self.keep_intermediates: + for f in self.files_to_clean: + os.unlink(f) + for d in self.dirs_to_clean[::-1]: + os.rmdir(d) + +# Follow UNIX convention for paths to use '/' instead of '\\' on Windows +def _to_unix_path(path: str) -> str: + return path.replace(os.sep, '/') + +def match_extensions(filename: str, extensions: Iterable) -> bool: + """Helper method to see if filename ends with certain extension""" + return any(filename.endswith(e) for e in extensions) + + +def _fnmatch(filepath, patterns): + return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns) + + +def matched_files_iter( + root_path: str, + includes: Iterable = (), + ignores: Iterable = (), + extensions: Iterable = (), + out_of_place_only: bool = False, + is_pytorch_extension: bool = False) -> Iterator[str]: + + exact_matches = set(includes) + + # This is a very rough heuristic; really, we want to avoid scanning + # any file which is not checked into source control, but this script + # needs to work even if you're in a Git or Hg checkout, so easier to + # just block the biggest time sinks that won't matter in the + # end. + for (abs_dirpath, dirs, filenames) in os.walk(root_path, topdown=True): + rel_dirpath = os.path.relpath(abs_dirpath, root_path) + if rel_dirpath == '.': + # Blah blah blah O(n) blah blah + if ".git" in dirs: + dirs.remove(".git") + if "build" in dirs: + dirs.remove("build") + if "third_party" in dirs: + dirs.remove("third_party") + dirs.append("third_party/nvfuser") + for filename in filenames: + filepath = _to_unix_path(os.path.join(abs_dirpath, filename)) + rel_filepath = _to_unix_path(os.path.join(rel_dirpath, filename)) + # We respect extensions, UNLESS you wrote the entire + # filename verbatim, in which case we always accept it + if ( + _fnmatch(filepath, includes) + and (not _fnmatch(filepath, ignores)) + and (match_extensions(filepath, extensions) or filepath in exact_matches) + ): + if not is_pytorch_extension: # for pytorch extensions, consider all files + if not is_pytorch_file(rel_filepath) and not is_caffe2_gpu_file(rel_filepath): + continue + if out_of_place_only and not is_out_of_place(rel_filepath): + continue + yield filepath + + +def preprocess_file_and_save_result( + output_directory: str, + filepath: str, + all_files: Iterable, + header_include_dirs: Iterable, + stats: dict[str, list], + hip_clang_launch: bool, + is_pytorch_extension: bool, + clean_ctx: GeneratedFileCleaner, + show_progress: bool) -> None: + fin_path = os.path.abspath(os.path.join(output_directory, filepath)) + hipify_result = HipifyResult(current_state=CurrentState.INITIALIZED, hipified_path=fin_path) + HIPIFY_FINAL_RESULT[fin_path] = hipify_result + result = preprocessor(output_directory, filepath, all_files, header_include_dirs, stats, + hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress) + + # Show what happened + if show_progress and "ignored" not in result.status: + print( + fin_path, "->", + result.hipified_path, result.status, flush=True) + + HIPIFY_FINAL_RESULT[fin_path] = result + + +def compute_stats(stats): + unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]} + + # Print the number of unsupported calls + print(f"Total number of unsupported CUDA function calls: {len(unsupported_calls):d}") + + # Print the list of unsupported calls + print(", ".join(unsupported_calls)) + + # Print the number of kernel launches + print(f"\nTotal number of replaced kernel launches: {len(stats['kernel_launches']):d}") + + +def add_dim3(kernel_string, cuda_kernel): + '''adds dim3() to the second and third arguments in the kernel launch''' + count = 0 + closure = 0 + kernel_string = kernel_string.replace("<<<", "").replace(">>>", "") + arg_locs: list[dict[str, int]] = [{} for _ in range(2)] + arg_locs[count]['start'] = 0 + for ind, c in enumerate(kernel_string): + if count > 1: + break + if c == "(": + closure += 1 + elif c == ")": + closure -= 1 + if (c == "," or ind == len(kernel_string) - 1) and closure == 0: + arg_locs[count]['end'] = ind + (c != ",") + count += 1 + if count < 2: + arg_locs[count]['start'] = ind + 1 + + first_arg_raw = kernel_string[arg_locs[0]['start']:arg_locs[0]['end'] + 1] + second_arg_raw = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']] + + first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ") + second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ") + + first_arg_dim3 = f"dim3({first_arg_clean})" + second_arg_dim3 = f"dim3({second_arg_clean})" + + first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3) + second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3) + cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3) + return cuda_kernel + + +RE_KERNEL_LAUNCH = re.compile(r'([ ]+)(detail?)::[ ]+\\\n[ ]+') + + +def processKernelLaunches(string, stats): + """ Replace the CUDA style Kernel launches with the HIP style kernel launches.""" + # Concat the namespace with the kernel names. (Find cleaner way of doing this later). + string = RE_KERNEL_LAUNCH.sub(lambda inp: f"{inp.group(1)}{inp.group(2)}::", string) + + def grab_method_and_template(in_kernel): + # The positions for relevant kernel components. + pos = { + "kernel_launch": {"start": in_kernel["start"], "end": in_kernel["end"]}, + "kernel_name": {"start": -1, "end": -1}, + "template": {"start": -1, "end": -1} + } + + # Count for balancing template + count = {"<>": 0} + + # Status for whether we are parsing a certain item. + START = 0 + AT_TEMPLATE = 1 + AFTER_TEMPLATE = 2 + AT_KERNEL_NAME = 3 + + status = START + + # Parse the string character by character + for i in range(pos["kernel_launch"]["start"] - 1, -1, -1): + char = string[i] + + # Handle Templating Arguments + if status in (START, AT_TEMPLATE): + if char == ">": + if status == START: + status = AT_TEMPLATE + pos["template"]["end"] = i + count["<>"] += 1 + + if char == "<": + count["<>"] -= 1 + if count["<>"] == 0 and (status == AT_TEMPLATE): + pos["template"]["start"] = i + status = AFTER_TEMPLATE + + # Handle Kernel Name + if status != AT_TEMPLATE: + if string[i].isalnum() or string[i] in {'(', ')', '_', ':', '#'}: + if status != AT_KERNEL_NAME: + status = AT_KERNEL_NAME + pos["kernel_name"]["end"] = i + + # Case: Kernel name starts the string. + if i == 0: + pos["kernel_name"]["start"] = 0 + + # Finished + return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])] + + else: + # Potential ending point if we're already traversing a kernel's name. + if status == AT_KERNEL_NAME: + pos["kernel_name"]["start"] = i + + # Finished + return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])] + + def find_kernel_bounds(string): + """Finds the starting and ending points for all kernel launches in the string.""" + kernel_end = 0 + kernel_positions = [] + + # Continue until we cannot find any more kernels anymore. + while string.find("<<<", kernel_end) != -1: + # Get kernel starting position (starting from the previous ending point) + kernel_start = string.find("<<<", kernel_end) + + # Get kernel ending position (adjust end point past the >>>) + kernel_end = string.find(">>>", kernel_start) + 3 + if kernel_end <= 0: + raise InputError("no kernel end found") + + # Add to list of traversed kernels + kernel_positions.append({"start": kernel_start, "end": kernel_end, + "group": string[kernel_start: kernel_end]}) + + return kernel_positions + + # Replace comments and string literals from the code so that find_kernel_bounds does not + # wrongly capture kernels in comments and string literals. + # This function replaces them with "x" to keep positions. + def mask_comments(string): + in_comment = '' + prev_c = '' + new_string = '' + for c in string: + if in_comment == '': + # Outside comments + if c == '/' and prev_c == '/': + in_comment = '//' + elif c == '*' and prev_c == '/': + in_comment = '/*' + elif c == '"' and prev_c != '\\' and prev_c != "'": + in_comment = '"' + elif in_comment == '//': + # In // xxx + if c == '\r' or c == '\n': + in_comment = '' + elif in_comment == '/*': + # In /* xxx */ + if c == '/' and prev_c == '*': + in_comment = '' + elif in_comment == '"': + # In "" + if c == '"' and prev_c != '\\': + in_comment = '' + prev_c = c + if in_comment == '': + new_string += c + else: + new_string += 'x' + return new_string + + # Grab positional ranges of all kernel launches + get_kernel_positions = list(find_kernel_bounds(mask_comments(string))) + output_string = string + + # Replace each CUDA kernel with a HIP kernel. + for kernel in get_kernel_positions: + # Get kernel components + params = grab_method_and_template(kernel) + + # Find parenthesis after kernel launch + parenthesis = string.find("(", kernel["end"]) + + # Extract cuda kernel + cuda_kernel = string[params[0]["start"]:parenthesis + 1] + kernel_string = string[kernel['start']:kernel['end']] + end_param_index = 0 if params[1]['end'] == -1 else 1 + kernel_name_with_template = string[params[0]['start']:params[end_param_index]['end'] + 1] + cuda_kernel_dim3 = add_dim3(kernel_string, cuda_kernel) + # Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size) + num_klp = len(extract_arguments(0, kernel["group"].replace("<<<", "(").replace(">>>", ")"))) + + hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace( + ">>>", ", 0" * (4 - num_klp) + ">>>").replace("<<<", ", ").replace( + ">>>", ", ").replace(kernel_name_with_template, "(" + kernel_name_with_template + ")") + + # Replace cuda kernel with hip kernel + output_string = output_string.replace(cuda_kernel, hip_kernel) + + # Update the statistics + stats["kernel_launches"].append(hip_kernel) + + return output_string + + +def find_closure_group(input_string, start, group): + """Generalization for finding a balancing closure group + + if group = ["(", ")"], then finds the first balanced parentheses. + if group = ["{", "}"], then finds the first balanced bracket. + + Given an input string, a starting position in the input string, and the group type, + find_closure_group returns the positions of group[0] and group[1] as a tuple. + + Example: + >>> find_closure_group("(hi)", 0, ["(", ")"]) + (0, 3) + """ + + inside_parenthesis = False + parens = 0 + pos = start + p_start, p_end = -1, -1 + + while pos < len(input_string): + if input_string[pos] == group[0]: + if inside_parenthesis is False: + inside_parenthesis = True + parens = 1 + p_start = pos + else: + parens += 1 + elif input_string[pos] == group[1] and inside_parenthesis: + parens -= 1 + + if parens == 0: + p_end = pos + return p_start, p_end + + pos += 1 + return None, None + + +def find_bracket_group(input_string, start): + """Finds the first balanced parantheses.""" + return find_closure_group(input_string, start, group=["{", "}"]) + + +def find_parentheses_group(input_string, start): + """Finds the first balanced bracket.""" + return find_closure_group(input_string, start, group=["(", ")"]) + + +RE_ASSERT = re.compile(r"\bassert[ ]*\(") + + +def replace_math_functions(input_string): + """FIXME: Temporarily replace std:: invocations of math functions + with non-std:: versions to prevent linker errors NOTE: This + can lead to correctness issues when running tests, since the + correct version of the math function (exp/expf) might not get + called. Plan is to remove this function once HIP supports + std:: math function calls inside device code + + """ + output_string = input_string + for func in MATH_TRANSPILATIONS: + output_string = output_string.replace(fr'{func}(', f'{MATH_TRANSPILATIONS[func]}(') + + return output_string + + +RE_SYNCTHREADS = re.compile(r":?:?\b(__syncthreads)\b(\w*\()") + + +def hip_header_magic(input_string): + """If the file makes kernel builtin calls and does not include the cuda_runtime.h header, + then automatically add an #include to match the "magic" includes provided by NVCC. + TODO: + Update logic to ignore cases where the cuda_runtime.h is included by another file. + """ + + # Copy the input. + output_string = input_string + + # Check if one of the following headers is already included. + headers = ["hip/hip_runtime.h", "hip/hip_runtime_api.h"] + if any(re.search(fr'#include ("{ext}"|<{ext}>)', output_string) for ext in headers): + return output_string + + # Rough logic to detect if we're inside device code + hasDeviceLogic: int + hasDeviceLogic = "hipLaunchKernelGGL" in output_string + hasDeviceLogic += "__global__" in output_string + hasDeviceLogic += "__shared__" in output_string + hasDeviceLogic += RE_SYNCTHREADS.search(output_string) is not None + + # If device logic found, provide the necessary header. + if hasDeviceLogic: + output_string = '#include "hip/hip_runtime.h"\n' + input_string + + return output_string + + +RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;") + + +def replace_extern_shared(input_string): + """Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead. + https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__ + Example: + "extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)" + "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)" + """ + output_string = input_string + output_string = RE_EXTERN_SHARED.sub( + lambda inp: f"HIP_DYNAMIC_SHARED({inp.group(1) or ''} {inp.group(2)}, {inp.group(3)})", output_string) + + return output_string + + +def get_hip_file_path(rel_filepath, is_pytorch_extension=False): + """ + Returns the new name of the hipified file + """ + # At the moment, some PyTorch source files are HIPified in place. The predicate + # is_out_of_place tells us if this is the case or not. + assert not os.path.isabs(rel_filepath) + if not is_pytorch_extension and not is_out_of_place(rel_filepath): + return rel_filepath + + dirpath, filename = os.path.split(rel_filepath) + root, ext = os.path.splitext(filename) + + # Here's the plan: + # + # In general, we need to disambiguate the HIPified filename so that + # it gets a different name from the original filename, so + # that we don't overwrite the original file + # + # There's a lot of different naming conventions across PyTorch + # and Caffe2, but the general recipe is to convert occurrences + # of cuda/gpu to hip, and add hip if there are no occurrences + # of cuda/gpu anywhere. + # + # Concretely, we do the following: + # + # - If there is a directory component named "cuda", replace + # it with "hip", AND + # + # - If the file name contains "CUDA", replace it with "HIP", AND + # + # - ALWAYS replace '.cu' with '.hip', because those files + # contain CUDA kernels that needs to be hipified and processed with + # hip compiler + # + # - If we are not hipifying a PyTorch extension, and the parent + # directory name did not change as a result of the above + # transformations, insert "hip" in the file path + # as the direct parent folder of the file + # + # - If we are hipifying a PyTorch extension, and the parent directory + # name as well as the filename (incl. extension) did not change as + # a result of the above transformations, insert "_hip" in the filename + # + # This isn't set in stone; we might adjust this to support other + # naming conventions. + + if ext == '.cu': + ext = '.hip' + + orig_filename = filename + orig_dirpath = dirpath + + dirpath = dirpath.replace('cuda', 'hip') + dirpath = dirpath.replace('CUDA', 'HIP') + dirpath = dirpath.replace('THC', 'THH') + + root = root.replace('cuda', 'hip') + root = root.replace('CUDA', 'HIP') + # Special case to handle caffe2/core/THCCachingAllocator + if dirpath != "caffe2/core": + root = root.replace('THC', 'THH') + + if not is_pytorch_extension and dirpath == orig_dirpath: + dirpath = os.path.join(dirpath, 'hip') + + if is_pytorch_extension and dirpath == orig_dirpath and (root + ext) == orig_filename: + root = root + "_hip" + + return os.path.join(dirpath, root + ext) + + +def is_out_of_place(rel_filepath): + assert not os.path.isabs(rel_filepath) + if rel_filepath.startswith("torch/"): + return False + if rel_filepath.startswith("third_party/nvfuser/"): + return False + if rel_filepath.startswith("tools/autograd/templates/"): + return False + return True + + +# Keep this synchronized with includes/ignores in build_amd.py +def is_pytorch_file(rel_filepath): + assert not os.path.isabs(rel_filepath) + if rel_filepath.startswith("aten/"): + if rel_filepath.startswith("aten/src/ATen/core/"): + return False + return True + if rel_filepath.startswith("torch/"): + return True + if rel_filepath.startswith("third_party/nvfuser/"): + return True + if rel_filepath.startswith("tools/autograd/templates/"): + return True + return False + + +def is_cusparse_file(rel_filepath): + if is_pytorch_file(rel_filepath): + return "sparse" in rel_filepath.lower() + return False + + +def is_special_file(rel_filepath): + if is_pytorch_file(rel_filepath): + if "sparse" in rel_filepath.lower(): + return True + elif "linalg" in rel_filepath.lower(): + if "batchlinearalgebralibblas" in rel_filepath.lower(): + return False # don't use "special" mappings for this specific linalg cublas file + return True + return False + +def is_caffe2_gpu_file(rel_filepath): + assert not os.path.isabs(rel_filepath) + if rel_filepath.startswith("c10/cuda"): + return True + filename = os.path.basename(rel_filepath) + _, ext = os.path.splitext(filename) + return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename) + +class TrieNode: + """A Trie node whose children are represented as a directory of char: TrieNode. + A special char '' represents end of word + """ + + def __init__(self): + self.children = {} + +class Trie: + """Creates a Trie out of a list of words. The trie can be exported to a Regex pattern. + The corresponding Regex should match much faster than a simple Regex union.""" + + def __init__(self): + """Initialize the trie with an empty root node.""" + self.root = TrieNode() + self._hash = hashlib.md5(usedforsecurity=False) + self._digest = self._hash.digest() + + def add(self, word): + """Add a word to the Trie. """ + self._hash.update(word.encode()) + self._digest = self._hash.digest() + node = self.root + + for char in word: + node.children.setdefault(char, TrieNode()) + node = node.children[char] + node.children[''] = True # Mark the end of the word + + def dump(self): + """Return the root node of Trie. """ + return self.root + + def quote(self, char): + """ Escape a char for regex. """ + return re.escape(char) + + def search(self, word): + """Search whether word is present in the Trie. + Returns True if yes, else return False""" + node = self.root + for char in word: + if char in node.children: + node = node.children[char] + else: + return False + + # make sure to check the end-of-word marker present + return '' in node.children + + @functools.lru_cache # noqa: B019 + def _pattern(self, root, digest): + """Convert a Trie into a regular expression pattern + + Memoized on the hash digest of the trie, which is built incrementally + during add(). + """ + node = root + + if "" in node.children and len(node.children.keys()) == 1: + return None + + alt = [] # store alternative patterns + cc = [] # store char to char classes + q = 0 # for node representing the end of word + for char in sorted(node.children.keys()): + if isinstance(node.children[char], TrieNode): + try: + recurse = self._pattern(node.children[char], self._digest) + alt.append(self.quote(char) + recurse) + except Exception: + cc.append(self.quote(char)) + else: + q = 1 + cconly = not len(alt) > 0 + + if len(cc) > 0: + if len(cc) == 1: + alt.append(cc[0]) + else: + alt.append('[' + ''.join(cc) + ']') + + if len(alt) == 1: + result = alt[0] + else: + result = "(?:" + "|".join(alt) + ")" + + if q: + if cconly: + result += "?" + else: + result = f"(?:{result})?" + return result + + def pattern(self): + """Export the Trie to a regex pattern.""" + return self._pattern(self.root, self._digest) + + def export_to_regex(self): + """Export the Trie to a regex pattern.""" + return self._pattern(self.root, self._digest) + +CAFFE2_TRIE = Trie() +CAFFE2_MAP = {} +PYTORCH_TRIE = Trie() +PYTORCH_MAP: dict[str, object] = {} + +# In PyTorch, we map cuBLAS->rocBLAS and cuSPARSE->hipSPARSE. Note the prefix, roc versus hip. +# The 'hip' APIs offer a more direct CUDA-friendly mapping, but calling rocBLAS directly has better performance. +# Unfortunately, the roc* types and hip* types differ, i.e., rocblas_float_complex versus hipComplex. +# In the case of SPARSE, we must use the hip types for complex instead of the roc types, +# but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority. +# Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place. +# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices. +# Similarly, "linalg" files require rocBLAS -> hipSOLVER so they also need special handling. +PYTORCH_SPECIAL_MAP = {} + +for mapping in CUDA_TO_HIP_MAPPINGS: + assert isinstance(mapping, Mapping) + for src, value in mapping.items(): + dst = value[0] + meta_data = value[1:] + if constants.API_CAFFE2 not in meta_data: + PYTORCH_TRIE.add(src) + # if src is already in PYTORCH_MAP and dst belongs to API_SPECIAL + # do not overwrite PYTORCH_MAP, store dst separately + if constants.API_SPECIAL in meta_data and PYTORCH_MAP.get(src, ""): + PYTORCH_SPECIAL_MAP[src] = dst + else: + PYTORCH_MAP[src] = dst + if constants.API_PYTORCH not in meta_data and constants.API_SPECIAL not in meta_data: + CAFFE2_TRIE.add(src) + CAFFE2_MAP[src] = dst +RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.export_to_regex()) +RE_PYTORCH_PREPROCESSOR = re.compile(fr'(?<=\W)({PYTORCH_TRIE.export_to_regex()})(?=\W)') + +RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"') +RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>') +RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"') +RE_CU_SUFFIX = re.compile(r'\.cu\b') # be careful not to pick up .cuh + +""" +Returns a HipifyResult object with the following details: + "hipified_path" : absolute path of hipified source file + "status" : "ok" if hipified file was written out + "skipped" if an identical hipified file already existed or hipified file couldn't be written out + "ignored" if the source file was a hipified file itself or not meant to be hipified + "current_state" : CurrentState.INITIALIZED if source file is first ready to be hipified + CurrentState.DONE if source file is done with hipification process +""" + + +def preprocessor( + output_directory: str, + filepath: str, + all_files: Iterable, + header_include_dirs: Iterable, + stats: dict[str, list], + hip_clang_launch: bool, + is_pytorch_extension: bool, + clean_ctx: GeneratedFileCleaner, + show_progress: bool) -> HipifyResult: + """ Executes the CUDA -> HIP conversion on the specified file. """ + fin_path = os.path.abspath(os.path.join(output_directory, filepath)) + filepath = _to_unix_path(filepath) + hipify_result = HIPIFY_FINAL_RESULT[fin_path] + if filepath not in all_files: + hipify_result.hipified_path = None + hipify_result.status = "[ignored, not to be hipified]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + + rel_filepath = _to_unix_path(os.path.relpath(filepath, output_directory)) + + with open(fin_path, encoding='utf-8') as fin: + if fin.readline() == HIPIFY_C_BREADCRUMB: + hipify_result.hipified_path = None + hipify_result.status = "[ignored, input is hipified output]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + fin.seek(0) + output_source = fin.read() + + orig_output_source = output_source + + # get_hip_file_path needs a relative path to work correctly + fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(rel_filepath, is_pytorch_extension))) + if not os.path.exists(os.path.dirname(fout_path)): + clean_ctx.makedirs(os.path.dirname(fout_path)) + + # unsupported_calls statistics reporting is broken atm + def pt_repl(m): + return PYTORCH_MAP[m.group(0)] + + def pt_special_repl(m): + # checks SPECIAL map first, and if a miss occurs, falls back to pytorch mappings + return PYTORCH_SPECIAL_MAP.get(m.group(0), pt_repl(m)) + + + if is_pytorch_extension: + output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source) + else: + if is_special_file(rel_filepath): + output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_special_repl, output_source) + elif is_pytorch_file(rel_filepath): + output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source) + else: + def c2_repl(m): + return CAFFE2_MAP[m.group(0)] + output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source) + + # Header rewrites + def mk_repl(templ, include_current_dir=True): + def repl(m): + f = m.group(1) + filename = os.path.basename(f) + if ( + f.startswith(("ATen/cuda", + "ATen/native/cuda", + "ATen/native/nested/cuda", + "ATen/native/quantized/cuda", + "ATen/native/sparse/cuda", + "ATen/native/transformers/cuda", + "THC/")) or + (f.startswith("THC") and not f.startswith("THCP")) + ): + return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension)) + # if filename is one of the files being hipified for this extension + if (is_pytorch_extension and any(s.endswith(filename) for s in all_files)): + header_dir = None + header_filepath = None + # If include_current_dir True, look first in same dir as the including source file + if include_current_dir: + header_dir_to_check = os.path.dirname(fin_path) + header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f)) + if os.path.exists(header_path_to_check): + header_dir = header_dir_to_check + header_filepath = header_path_to_check + # If not found, look in include dirs one by one and first match wins + if header_filepath is None: + for header_include_dir in header_include_dirs: + header_dir_to_check = os.path.join(output_directory, header_include_dir) + header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f)) + if os.path.exists(header_path_to_check): + header_dir = header_dir_to_check + header_filepath = header_path_to_check + # If header file not found, keep as is + if header_filepath is None: + return m.group(0) + # Hipify header file first if needed + if header_filepath not in HIPIFY_FINAL_RESULT: + preprocess_file_and_save_result(output_directory, + header_filepath, + all_files, header_include_dirs, stats, hip_clang_launch, + is_pytorch_extension, clean_ctx, show_progress) + elif header_filepath in HIPIFY_FINAL_RESULT: + header_result = HIPIFY_FINAL_RESULT[header_filepath] + if header_result.current_state == CurrentState.INITIALIZED: + # get_hip_file_path needs a relative path to work correctly + header_rel_path = os.path.relpath(header_filepath, output_directory) + header_fout_path = os.path.abspath(os.path.join(output_directory, + get_hip_file_path(header_rel_path, is_pytorch_extension))) + header_result.hipified_path = header_fout_path + HIPIFY_FINAL_RESULT[header_filepath] = header_result + return templ.format(os.path.relpath(header_fout_path if header_fout_path is not None + else header_filepath, header_dir)) + hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath].hipified_path + return templ.format(_to_unix_path(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None + else header_filepath, header_dir))) + + return m.group(0) + return repl + output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"', True), output_source) + output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>', False), output_source) + output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source) + + # CMakeLists.txt rewrites + if filepath.endswith('CMakeLists.txt'): + output_source = output_source.replace('CUDA', 'HIP') + output_source = output_source.replace('THC', 'THH') + output_source = RE_CU_SUFFIX.sub('.hip', output_source) + + # Perform Kernel Launch Replacements + if not hip_clang_launch: + output_source = processKernelLaunches(output_source, stats) + + # Replace std:: with non-std:: versions + if (filepath.endswith((".cu", ".cuh"))) and "PowKernel" not in filepath: + output_source = replace_math_functions(output_source) + + # Include header if device code is contained. + output_source = hip_header_magic(output_source) + + # Replace the extern __shared__ + # NOTE: No longer needed after transition from hcc to hipclang. + # output_source = replace_extern_shared(output_source) + + # Don't write out identical hipified files for extensions if dirpath has not changed + if ( + is_pytorch_extension + and orig_output_source == output_source + and os.path.dirname(fin_path) == os.path.dirname(fout_path) + ): + hipify_result.hipified_path = fin_path + hipify_result.status = "[skipped, no changes]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + + # Add hipify breadcrumb for C-style files to avoid re-hipification + if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")): + output_source = HIPIFY_C_BREADCRUMB + output_source + + do_write = True + if os.path.exists(fout_path): + with open(fout_path, encoding='utf-8') as fout_old: + do_write = fout_old.read() != output_source + if do_write: + try: + with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout: + fout.write(output_source) + hipify_result.hipified_path = fout_path + hipify_result.status = "[ok]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + except OSError as e: + print(f'{bcolors.WARNING}Failed to save {fout_path} with "{e.strerror}", leaving {fin_path} unchanged.{bcolors.ENDC}', + file=sys.stderr) + hipify_result.hipified_path = fin_path + hipify_result.status = "[skipped, no permissions]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + else: + hipify_result.hipified_path = fout_path + hipify_result.status = "[skipped, already hipified]" + hipify_result.current_state = CurrentState.DONE + return hipify_result + +def file_specific_replacement(filepath, search_string, replace_string, strict=False): + with openf(filepath, "r+") as f: + contents = f.read() + if strict: + contents = re.sub(fr'\b({re.escape(search_string)})\b', lambda x: replace_string, contents) + else: + contents = contents.replace(search_string, replace_string) + f.seek(0) + f.write(contents) + f.truncate() + + +def file_add_header(filepath, header): + with openf(filepath, "r+") as f: + contents = f.read() + if header[0] != "<" and header[-1] != ">": + header = f'"{header}"' + contents = (f'#include {header} \n') + contents + f.seek(0) + f.write(contents) + f.truncate() + + +def fix_static_global_kernels(in_txt): + """Static global kernels in HIP results in a compilation error.""" + in_txt = in_txt.replace(" __global__ static", "__global__") + return in_txt + + +RE_INCLUDE = re.compile(r"#include .*\n") + + +def extract_arguments(start, string): + """ Return the list of arguments in the upcoming function parameter closure. + Example: + string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))' + arguments (output): + '[{'start': 1, 'end': 7}, + {'start': 8, 'end': 16}, + {'start': 17, 'end': 19}, + {'start': 20, 'end': 53}]' + """ + + arguments = [] + closures = { + "<": 0, + "(": 0 + } + current_position = start + argument_start_pos = current_position + 1 + + # Search for final parenthesis + while current_position < len(string): + if string[current_position] == "(": + closures["("] += 1 + elif string[current_position] == ")": + closures["("] -= 1 + elif string[current_position] == "<": + closures["<"] += 1 + elif string[current_position] == ">" and string[current_position - 1] != "-" and closures["<"] > 0: + closures["<"] -= 1 + + # Finished all arguments + if closures["("] == 0 and closures["<"] == 0: + # Add final argument + arguments.append({"start": argument_start_pos, "end": current_position}) + break + + # Finished current argument + if closures["("] == 1 and closures["<"] == 0 and string[current_position] == ",": + arguments.append({"start": argument_start_pos, "end": current_position}) + argument_start_pos = current_position + 1 + + current_position += 1 + + return arguments + + +def str2bool(v): + """ArgumentParser doesn't support type=bool. Thus, this helper method will convert + from possible string types to True / False.""" + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def hipify( + project_directory: str, + show_detailed: bool = False, + extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"), + header_extensions: Iterable = (".cuh", ".h", ".hpp"), + output_directory: str = "", + header_include_dirs: Iterable = (), + includes: Iterable = ('*',), + extra_files: Iterable = (), + out_of_place_only: bool = False, + ignores: Iterable = (), + show_progress: bool = True, + hip_clang_launch: bool = False, + is_pytorch_extension: bool = False, + hipify_extra_files_only: bool = False, + clean_ctx: Optional[GeneratedFileCleaner] = None +) -> HipifyFinalResult: + if project_directory == "": + project_directory = os.getcwd() + + # Verify the project directory exists. + if not os.path.exists(project_directory): + print("The project folder specified does not exist.") + sys.exit(1) + + # If no output directory, provide a default one. + if not output_directory: + project_directory.rstrip("/") + output_directory = project_directory + "_amd" + + if project_directory != output_directory: + includes = [include.replace(project_directory, output_directory) for include in includes] + ignores = [ignore.replace(project_directory, output_directory) for ignore in ignores] + + # Copy from project directory to output directory if not done already. + if not os.path.exists(output_directory): + shutil.copytree(project_directory, output_directory) + + includes = list(map(_to_unix_path, includes)) + ignores = list(map(_to_unix_path, ignores)) + + all_files = list(matched_files_iter(output_directory, includes=includes, + ignores=ignores, extensions=extensions, + out_of_place_only=out_of_place_only, + is_pytorch_extension=is_pytorch_extension)) + all_files_set = set(all_files) + for f in extra_files: + if not os.path.isabs(f): + f = os.path.join(output_directory, f) + if f not in all_files_set: + all_files.append(f) + + # List all files in header_include_paths to ensure they are hipified + from pathlib import Path + for header_include_dir in header_include_dirs: + if os.path.isabs(header_include_dir): + header_include_dir_path = Path(header_include_dir) + else: + header_include_dir_path = Path(os.path.join(output_directory, header_include_dir)) + all_files.extend( + str(path) for path in header_include_dir_path.rglob('*') if path.is_file() + and _fnmatch(str(path), includes) + and (not _fnmatch(str(path), ignores)) + and match_extensions(path.name, header_extensions) + ) + + if clean_ctx is None: + clean_ctx = GeneratedFileCleaner(keep_intermediates=True) + + # Preprocessing statistics. + stats: dict[str, list] = {"unsupported_calls": [], "kernel_launches": []} + + for filepath in (all_files if not hipify_extra_files_only else extra_files): + preprocess_file_and_save_result(output_directory, filepath, all_files, header_include_dirs, + stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress) + + print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr) + + # Show detailed summary + if show_detailed: + compute_stats(stats) + + return HIPIFY_FINAL_RESULT diff --git a/.venv/lib/python3.12/site-packages/torch/utils/hipify/version.py b/.venv/lib/python3.12/site-packages/torch/utils/hipify/version.py new file mode 100644 index 0000000000000000000000000000000000000000..1f356cc57bfa00a3b251402604c54702fb414c96 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/hipify/version.py @@ -0,0 +1 @@ +__version__ = '1.0.0' diff --git a/.venv/lib/python3.12/site-packages/torch/utils/model_dump/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/model_dump/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5ab8fd9a35e11fa0571f04adf25181c1cd14d42b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/model_dump/__init__.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs +""" +model_dump: a one-stop shop for TorchScript model inspection. + +The goal of this tool is to provide a simple way to extract lots of +useful information from a TorchScript model and make it easy for humans +to consume. It (mostly) replaces zipinfo, common uses of show_pickle, +and various ad-hoc analysis notebooks. + +The tool extracts information from the model and serializes it as JSON. +That JSON can then be rendered by an HTML+JS page, either by +loading the JSON over HTTP or producing a fully self-contained page +with all of the code and data burned-in. +""" + +# Maintainer notes follow. +""" +The implementation strategy has tension between 3 goals: +- Small file size. +- Fully self-contained. +- Easy, modern JS environment. +Using Preact and HTM achieves 1 and 2 with a decent result for 3. +However, the models I tested with result in ~1MB JSON output, +so even using something heavier like full React might be tolerable +if the build process can be worked out. + +One principle I have followed that I think is very beneficial +is to keep the JSON data as close as possible to the model +and do most of the rendering logic on the client. +This makes for easier development (just refresh, usually), +allows for more laziness and dynamism, and lets us add more +views of the same data without bloating the HTML file. + +Currently, this code doesn't actually load the model or even +depend on any part of PyTorch. I don't know if that's an important +feature to maintain, but it's probably worth preserving the ability +to run at least basic analysis on models that cannot be loaded. + +I think the easiest way to develop this code is to cd into model_dump and +run "python -m http.server", then load http://localhost:8000/skeleton.html +in the browser. In another terminal, run +"python -m torch.utils.model_dump --style=json FILE > \ + torch/utils/model_dump/model_info.json" +every time you update the Python code or model. +When you update JS, just refresh. + +Possible improvements: + - Fix various TODO comments in this file and the JS. + - Make the HTML much less janky, especially the auxiliary data panel. + - Make the auxiliary data panel start small, expand when + data is available, and have a button to clear/contract. + - Clean up the JS. There's a lot of copypasta because + I don't really know how to use Preact. + - Make the HTML render and work nicely inside a Jupyter notebook. + - Add the ability for JS to choose the URL to load the JSON based + on the page URL (query or hash). That way we could publish the + inlined skeleton once and have it load various JSON blobs. + - Add a button to expand all expandable sections so ctrl-F works well. + - Add hyperlinking from data to code, and code to code. + - Add hyperlinking from debug info to Diffusion. + - Make small tensor contents available. + - Do something nice for quantized models + (they probably don't work at all right now). +""" + +import argparse +import io +import json +import os +import pickle +import pprint +import re +import sys +import urllib.parse +import zipfile +from pathlib import Path +import warnings + +import torch.utils.show_pickle + + +DEFAULT_EXTRA_FILE_SIZE_LIMIT = 16 * 1024 + +__all__ = ['get_storage_info', 'hierarchical_pickle', 'get_model_info', 'get_inline_skeleton', + 'burn_in_info', 'get_info_and_burn_skeleton'] + +def get_storage_info(storage): + assert isinstance(storage, torch.utils.show_pickle.FakeObject) + assert storage.module == "pers" + assert storage.name == "obj" + assert storage.state is None + assert isinstance(storage.args, tuple) + assert len(storage.args) == 1 + sa = storage.args[0] + assert isinstance(sa, tuple) + assert len(sa) == 5 + assert sa[0] == "storage" + assert isinstance(sa[1], torch.utils.show_pickle.FakeClass) + assert sa[1].module == "torch" + assert sa[1].name.endswith("Storage") + storage_info = [sa[1].name.replace("Storage", "")] + list(sa[2:]) + return storage_info + + +def hierarchical_pickle(data): + if isinstance(data, (bool, int, float, str, type(None))): + return data + if isinstance(data, list): + return [hierarchical_pickle(d) for d in data] + if isinstance(data, tuple): + return { + "__tuple_values__": hierarchical_pickle(list(data)), + } + if isinstance(data, dict): + return { + "__is_dict__": True, + "keys": hierarchical_pickle(list(data.keys())), + "values": hierarchical_pickle(list(data.values())), + } + if isinstance(data, torch.utils.show_pickle.FakeObject): + typename = f"{data.module}.{data.name}" + if ( + typename.startswith(('__torch__.', 'torch.jit.LoweredWrapper.', 'torch.jit.LoweredModule.')) + ): + assert data.args == () + return { + "__module_type__": typename, + "state": hierarchical_pickle(data.state), + } + if typename == "torch._utils._rebuild_tensor_v2": + assert data.state is None + storage, offset, size, stride, requires_grad, *_ = data.args + storage_info = get_storage_info(storage) + return {"__tensor_v2__": [storage_info, offset, size, stride, requires_grad]} + if typename == "torch._utils._rebuild_qtensor": + assert data.state is None + storage, offset, size, stride, quantizer, requires_grad, *_ = data.args + storage_info = get_storage_info(storage) + assert isinstance(quantizer, tuple) + assert isinstance(quantizer[0], torch.utils.show_pickle.FakeClass) + assert quantizer[0].module == "torch" + if quantizer[0].name == "per_tensor_affine": + assert len(quantizer) == 3 + assert isinstance(quantizer[1], float) + assert isinstance(quantizer[2], int) + quantizer_extra = list(quantizer[1:3]) + else: + quantizer_extra = [] + quantizer_json = [quantizer[0].name] + quantizer_extra + return {"__qtensor__": [storage_info, offset, size, stride, quantizer_json, requires_grad]} + if typename == "torch.jit._pickle.restore_type_tag": + assert data.state is None + obj, typ = data.args + assert isinstance(typ, str) + return hierarchical_pickle(obj) + if re.fullmatch(r"torch\.jit\._pickle\.build_[a-z]+list", typename): + assert data.state is None + ls, = data.args + assert isinstance(ls, list) + return hierarchical_pickle(ls) + if typename == "torch.device": + assert data.state is None + name, = data.args + assert isinstance(name, str) + # Just forget that it was a device and return the name. + return name + if typename == "builtin.UnicodeDecodeError": + assert data.state is None + msg, = data.args + assert isinstance(msg, str) + # Hack: Pretend this is a module so we don't need custom serialization. + # Hack: Wrap the message in a tuple so it looks like a nice state object. + # TODO: Undo at least that second hack. We should support string states. + return { + "__module_type__": typename, + "state": hierarchical_pickle((msg,)), + } + raise Exception(f"Can't prepare fake object of type for JS: {typename}") # noqa: TRY002 + raise Exception(f"Can't prepare data of type for JS: {type(data)}") # noqa: TRY002 + + +def get_model_info( + path_or_file, + title=None, + extra_file_size_limit=DEFAULT_EXTRA_FILE_SIZE_LIMIT): + """Get JSON-friendly information about a model. + + The result is suitable for being saved as model_info.json, + or passed to burn_in_info. + """ + + if isinstance(path_or_file, os.PathLike): + default_title = os.fspath(path_or_file) + file_size = path_or_file.stat().st_size # type: ignore[attr-defined] + elif isinstance(path_or_file, str): + default_title = path_or_file + file_size = Path(path_or_file).stat().st_size + else: + default_title = "buffer" + path_or_file.seek(0, io.SEEK_END) + file_size = path_or_file.tell() + path_or_file.seek(0) + + title = title or default_title + + with zipfile.ZipFile(path_or_file) as zf: + path_prefix = None + zip_files = [] + for zi in zf.infolist(): + prefix = re.sub("/.*", "", zi.filename) + if path_prefix is None: + path_prefix = prefix + elif prefix != path_prefix: + raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}") # noqa: TRY002 + zip_files.append(dict( + filename=zi.filename, + compression=zi.compress_type, + compressed_size=zi.compress_size, + file_size=zi.file_size, + )) + + assert path_prefix is not None + version = zf.read(path_prefix + "/version").decode("utf-8").strip() + + def get_pickle(name): + assert path_prefix is not None + with zf.open(path_prefix + f"/{name}.pkl") as handle: + raw = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load() + return hierarchical_pickle(raw) + + model_data = get_pickle("data") + constants = get_pickle("constants") + + # Intern strings that are likely to be re-used. + # Pickle automatically detects shared structure, + # so re-used strings are stored efficiently. + # However, JSON has no way of representing this, + # so we have to do it manually. + interned_strings : dict[str, int] = {} + + def ist(s): + if s not in interned_strings: + interned_strings[s] = len(interned_strings) + return interned_strings[s] + + code_files = {} + for zi in zf.infolist(): + if not zi.filename.endswith(".py"): + continue + with zf.open(zi) as handle: + raw_code = handle.read() + with zf.open(zi.filename + ".debug_pkl") as handle: + raw_debug = handle.read() + + # Parse debug info and add begin/end markers if not present + # to ensure that we cover the entire source code. + debug_info_t = pickle.loads(raw_debug) + text_table = None + + if (len(debug_info_t) == 3 and + isinstance(debug_info_t[0], str) and + debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'): + _, text_table, content = debug_info_t + + def parse_new_format(line): + # (0, (('', '', 0), 0, 0)) + num, ((text_indexes, fname_idx, offset), start, end), tag = line + text = ''.join(text_table[x] for x in text_indexes) # type: ignore[index] + fname = text_table[fname_idx] # type: ignore[index] + return num, ((text, fname, offset), start, end), tag + + debug_info_t = map(parse_new_format, content) + + debug_info = list(debug_info_t) + if not debug_info: + debug_info.append((0, (('', '', 0), 0, 0))) + if debug_info[-1][0] != len(raw_code): + debug_info.append((len(raw_code), (('', '', 0), 0, 0))) + + code_parts = [] + for di, di_next in zip(debug_info, debug_info[1:]): + start, source_range, *_ = di + end = di_next[0] + assert end > start + source, s_start, s_end = source_range + s_text, s_file, s_line = source + # TODO: Handle this case better. TorchScript ranges are in bytes, + # but JS doesn't really handle byte strings. + # if bytes and chars are not equivalent for this string, + # zero out the ranges so we don't highlight the wrong thing. + if len(s_text) != len(s_text.encode("utf-8")): + s_start = 0 + s_end = 0 + text = raw_code[start:end] + code_parts.append([text.decode("utf-8"), ist(s_file), s_line, ist(s_text), s_start, s_end]) + code_files[zi.filename] = code_parts + + extra_files_json_pattern = re.compile(re.escape(path_prefix) + "/extra/.*\\.json") + extra_files_jsons = {} + for zi in zf.infolist(): + if not extra_files_json_pattern.fullmatch(zi.filename): + continue + if zi.file_size > extra_file_size_limit: + continue + with zf.open(zi) as handle: + try: + json_content = json.load(handle) + extra_files_jsons[zi.filename] = json_content + except json.JSONDecodeError: + extra_files_jsons[zi.filename] = "INVALID JSON" + + always_render_pickles = { + "bytecode.pkl", + } + extra_pickles = {} + for zi in zf.infolist(): + if not zi.filename.endswith(".pkl"): + continue + with zf.open(zi) as handle: + # TODO: handle errors here and just ignore the file? + # NOTE: For a lot of these files (like bytecode), + # we could get away with just unpickling, but this should be safer. + obj = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load() + buf = io.StringIO() + pprint.pprint(obj, buf) + contents = buf.getvalue() + # Checked the rendered length instead of the file size + # because pickles with shared structure can explode in size during rendering. + if os.path.basename(zi.filename) not in always_render_pickles and \ + len(contents) > extra_file_size_limit: + continue + extra_pickles[zi.filename] = contents + + return {"model": dict( + title=title, + file_size=file_size, + version=version, + zip_files=zip_files, + interned_strings=list(interned_strings), + code_files=code_files, + model_data=model_data, + constants=constants, + extra_files_jsons=extra_files_jsons, + extra_pickles=extra_pickles, + )} + + +def get_inline_skeleton(): + """Get a fully-inlined skeleton of the frontend. + + The returned HTML page has no external network dependencies for code. + It can load model_info.json over HTTP, or be passed to burn_in_info. + """ + + import importlib.resources + + skeleton = importlib.resources.read_text(__package__, "skeleton.html") + js_code = importlib.resources.read_text(__package__, "code.js") + for js_module in ["preact", "htm"]: + js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs") + js_url = "data:application/javascript," + urllib.parse.quote(js_lib) + js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url) + skeleton = skeleton.replace(' src="./code.js">', ">\n" + js_code) + return skeleton + + +def burn_in_info(skeleton, info): + """Burn model info into the HTML skeleton. + + The result will render the hard-coded model info and + have no external network dependencies for code or data. + """ + + # Note that Python's json serializer does not escape slashes in strings. + # Since we're inlining this JSON directly into a script tag, a string + # containing "" would end the script prematurely and + # mess up our page. Unconditionally escape fixes that. + return skeleton.replace( + "BURNED_IN_MODEL_INFO = null", + "BURNED_IN_MODEL_INFO = " + json.dumps(info, sort_keys=True).replace("/", "\\/")) + + +def get_info_and_burn_skeleton(path_or_bytesio, **kwargs): + model_info = get_model_info(path_or_bytesio, **kwargs) + skeleton = get_inline_skeleton() + page = burn_in_info(skeleton, model_info) + return page + + +def main(argv, *, stdout=None): + warnings.warn("torch.utils.model_dump is deprecated and will be removed in a future PyTorch release.") + parser = argparse.ArgumentParser() + parser.add_argument("--style", choices=["json", "html"]) + parser.add_argument("--title") + parser.add_argument("model") + args = parser.parse_args(argv[1:]) + + info = get_model_info(args.model, title=args.title) + + output = stdout or sys.stdout + + if args.style == "json": + output.write(json.dumps(info, sort_keys=True) + "\n") + elif args.style == "html": + skeleton = get_inline_skeleton() + page = burn_in_info(skeleton, info) + output.write(page) + else: + raise Exception("Invalid style") # noqa: TRY002 diff --git a/.venv/lib/python3.12/site-packages/torch/utils/model_dump/__main__.py b/.venv/lib/python3.12/site-packages/torch/utils/model_dump/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d4bdac389bb1f270d74efb6c876258d46077110 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/model_dump/__main__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +import sys +from . import main + +sys.exit(main(sys.argv)) diff --git a/.venv/lib/python3.12/site-packages/torch/utils/model_dump/code.js b/.venv/lib/python3.12/site-packages/torch/utils/model_dump/code.js new file mode 100644 index 0000000000000000000000000000000000000000..173ddfb639d847159ee4fdf46691404bf1bbb7a3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/model_dump/code.js @@ -0,0 +1,689 @@ +import { h, Component, render } from 'https://unpkg.com/preact?module'; +import htm from 'https://unpkg.com/htm?module'; + +const html = htm.bind(h); + +const BURNED_IN_MODEL_INFO = null; + +// https://stackoverflow.com/a/20732091 +function humanFileSize(size) { + if (size == 0) { return "0 B"; } + var i = Math.floor( Math.log(size) / Math.log(1024) ); + return (size / Math.pow(1024, i)).toFixed(2) * 1 + ' ' + ['B', 'kB', 'MB', 'GB', 'TB'][i]; +} + +function caret(down) { + return down ? "\u25BE" : "\u25B8"; +} + +class Blamer { + constructor() { + this.blame_on_click = false; + this.aux_content_pane = null; + } + + setAuxContentPane(pane) { + this.aux_content_pane = pane; + } + + readyBlame() { + this.blame_on_click = true; + } + + maybeBlame(arg) { + if (!this.blame_on_click) { + return; + } + this.blame_on_click = false; + if (!this.aux_content_pane) { + return; + } + this.aux_content_pane.doBlame(arg); + } +} + +let blame = new Blamer(); + +class Hider extends Component { + constructor() { + super(); + this.state = { shown: null }; + } + + componentDidMount() { + this.setState({ shown: this.props.shown === "true" }); + } + + render({name, children}, {shown}) { + let my_caret = html` this.click()} >${caret(shown)}`; + return html`
+

${my_caret} ${name}

+
${shown ? this.props.children : []}
`; + } + + click() { + this.setState({shown: !this.state.shown}); + } +} + +function ModelSizeSection({model: {file_size, zip_files}}) { + let store_size = 0; + let compr_size = 0; + for (const zi of zip_files) { + if (zi.compression === 0) { + // TODO: Maybe check that compressed_size === file_size. + store_size += zi.compressed_size; + } else { + compr_size += zi.compressed_size; + } + } + let zip_overhead = file_size - store_size - compr_size; + // TODO: Better formatting. Right-align this. + return html` + <${Hider} name="Model Size" shown=true> +
.
+      Model size: ${file_size} (${humanFileSize(file_size)})
+      Stored files: ${store_size} (${humanFileSize(store_size)})
+      Compressed files: ${compr_size} (${humanFileSize(compr_size)})
+      Zip overhead: ${zip_overhead} (${humanFileSize(zip_overhead)})
+    
`; +} + +function StructuredDataSection({name, data, shown}) { + return html` + <${Hider} name=${name} shown=${shown}> +
+ <${StructuredData} data=${data} indent="" prefix=""/> +
`; +} + +class StructuredData extends Component { + constructor() { + super(); + this.state = { shown: false }; + + this.INLINE_TYPES = new Set(["boolean", "number", "string"]) + this.IGNORED_STATE_KEYS = new Set(["training", "_is_full_backward_hook"]) + } + + click() { + this.setState({shown: !this.state.shown}); + } + + expando(data) { + if (data === null || this.INLINE_TYPES.has(typeof(data))) { + return false; + } + if (typeof(data) != "object") { + throw new Error("Not an object"); + } + if (Array.isArray(data)) { + // TODO: Maybe show simple lists and tuples on one line. + return true; + } + if (data.__tuple_values__) { + // TODO: Maybe show simple lists and tuples on one line. + return true; + } + if (data.__is_dict__) { + // TODO: Maybe show simple (empty?) dicts on one line. + return true; + } + if (data.__module_type__) { + return true; + } + if (data.__tensor_v2__) { + return false; + } + if (data.__qtensor__) { + return false; + } + throw new Error("Can't handle data type.", data); + } + + renderHeadline(data) { + if (data === null) { + return "None"; + } + if (typeof(data) == "boolean") { + const sd = String(data); + return sd.charAt(0).toUpperCase() + sd.slice(1); + } + if (typeof(data) == "number") { + return JSON.stringify(data); + } + if (typeof(data) == "string") { + return JSON.stringify(data); + } + if (typeof(data) != "object") { + throw new Error("Not an object"); + } + if (Array.isArray(data)) { + return "list(["; + } + if (data.__tuple_values__) { + return "tuple(("; + } + if (data.__is_dict__) { + return "dict({"; + } + if (data.__module_type__) { + return data.__module_type__ + "()"; + } + if (data.__tensor_v2__) { + const [storage, offset, size, stride, grad] = data.__tensor_v2__; + const [dtype, key, device, numel] = storage; + return this.renderTensor( + "tensor", dtype, key, device, numel, offset, size, stride, grad, []); + } + if (data.__qtensor__) { + const [storage, offset, size, stride, quantizer, grad] = data.__qtensor__; + const [dtype, key, device, numel] = storage; + let extra_parts = []; + if (quantizer[0] == "per_tensor_affine") { + extra_parts.push(`scale=${quantizer[1]}`); + extra_parts.push(`zero_point=${quantizer[2]}`); + } else { + extra_parts.push(`quantizer=${quantizer[0]}`); + } + return this.renderTensor( + "qtensor", dtype, key, device, numel, offset, size, stride, grad, extra_parts); + } + throw new Error("Can't handle data type.", data); + } + + renderTensor( + prefix, + dtype, + storage_key, + device, + storage_numel, + offset, + size, + stride, + grad, + extra_parts) { + let parts = [ + "(" + size.join(",") + ")", + dtype, + ]; + parts.push(...extra_parts); + if (device != "cpu") { + parts.push(device); + } + if (grad) { + parts.push("grad"); + } + // TODO: Check stride and indicate if the tensor is channels-last or non-contiguous + // TODO: Check size, stride, offset, and numel and indicate if + // the tensor doesn't use all data in storage. + // TODO: Maybe show key? + void(offset); + void(stride); + void(storage_key); + void(storage_numel); + return prefix + "(" + parts.join(", ") + ")"; + } + + renderBody(indent, data) { + if (data === null || this.INLINE_TYPES.has(typeof(data))) { + throw "Should not reach here." + } + if (typeof(data) != "object") { + throw new Error("Not an object"); + } + if (Array.isArray(data)) { + let new_indent = indent + "\u00A0\u00A0"; + let parts = []; + for (let idx = 0; idx < data.length; idx++) { + // Does it make sense to put explicit index numbers here? + parts.push(html`
<${StructuredData} prefix=${idx + ": "} indent=${new_indent} data=${data[idx]} />`); + } + return parts; + } + if (data.__tuple_values__) { + // Handled the same as lists. + return this.renderBody(indent, data.__tuple_values__); + } + if (data.__is_dict__) { + let new_indent = indent + "\u00A0\u00A0"; + let parts = []; + for (let idx = 0; idx < data.keys.length; idx++) { + if (typeof(data.keys[idx]) != "string") { + parts.push(html`
${new_indent}Non-string key`); + } else { + parts.push(html`
<${StructuredData} prefix=${data.keys[idx] + ": "} indent=${new_indent} data=${data.values[idx]} />`); + } + } + return parts; + } + if (data.__module_type__) { + const mstate = data.state; + if (mstate === null || typeof(mstate) != "object") { + throw new Error("Bad module state"); + } + let new_indent = indent + "\u00A0\u00A0"; + let parts = []; + if (mstate.__is_dict__) { + // TODO: Less copy/paste between this and normal dicts. + for (let idx = 0; idx < mstate.keys.length; idx++) { + if (typeof(mstate.keys[idx]) != "string") { + parts.push(html`
${new_indent}Non-string key`); + } else if (this.IGNORED_STATE_KEYS.has(mstate.keys[idx])) { + // Do nothing. + } else { + parts.push(html`
<${StructuredData} prefix=${mstate.keys[idx] + ": "} indent=${new_indent} data=${mstate.values[idx]} />`); + } + } + } else if (mstate.__tuple_values__) { + parts.push(html`
<${StructuredData} prefix="" indent=${new_indent} data=${mstate} />`); + } else if (mstate.__module_type__) { + // We normally wouldn't have the state of a module be another module, + // but we use "modules" to encode special values (like Unicode decode + // errors) that might be valid states. Just go with it. + parts.push(html`
<${StructuredData} prefix="" indent=${new_indent} data=${mstate} />`); + } else { + throw new Error("Bad module state"); + } + return parts; + } + if (data.__tensor_v2__) { + throw "Should not reach here." + } + if (data.__qtensor__) { + throw "Should not reach here." + } + throw new Error("Can't handle data type.", data); + } + + render({data, indent, prefix}, {shown}) { + const exp = this.expando(data) ? html` this.click()} >${caret(shown)} ` : ""; + const headline = this.renderHeadline(data); + const body = shown ? this.renderBody(indent, data) : ""; + return html`${indent}${exp}${prefix}${headline}${body}`; + } +} + +function ZipContentsSection({model: {zip_files}}) { + // TODO: Add human-readable sizes? + // TODO: Add sorting options? + // TODO: Add hierarchical collapsible tree? + return html` + <${Hider} name="Zip Contents" shown=false> + + + + + + + + + + + ${zip_files.map(zf => html` + + + + + `)} + +
ModeSizeCompressedName
${{0: "store", 8: "deflate"}[zf.compression] || zf.compression}${zf.file_size}${zf.compressed_size}${zf.filename}
`; +} + +function CodeSection({model: {code_files}}) { + return html` + <${Hider} name="Code" shown=false> +
+ ${Object.entries(code_files).map(([fn, code]) => html`<${OneCodeSection} + filename=${fn} code=${code} />`)} +
`; +} + +class OneCodeSection extends Component { + constructor() { + super(); + this.state = { shown: false }; + } + + click() { + const shown = !this.state.shown; + this.setState({shown: shown}); + } + + render({filename, code}, {shown}) { + const header = html` +

+ this.click()} >${caret(shown)} + ${filename}

+ `; + if (!shown) { + return header; + } + return html` + ${header} +
${code.map(c => this.renderBlock(c))}
+ `; + } + + renderBlock([text, ist_file, line, ist_s_text, s_start, s_end]) { + return html` blame.maybeBlame({ist_file, line, ist_s_text, s_start, s_end})} + >${text}`; + } +} + +function ExtraJsonSection({files}) { + return html` + <${Hider} name="Extra files (JSON)" shown=false> +
+

Use "Log Raw Model Info" for hierarchical view in browser console.

+ ${Object.entries(files).map(([fn, json]) => html`<${OneJsonSection} + filename=${fn} json=${json} />`)} +
`; +} + +class OneJsonSection extends Component { + constructor() { + super(); + this.state = { shown: false }; + } + + click() { + const shown = !this.state.shown; + this.setState({shown: shown}); + } + + render({filename, json}, {shown}) { + const header = html` +

+ this.click()} >${caret(shown)} + ${filename}

+ `; + if (!shown) { + return header; + } + return html` + ${header} +
${JSON.stringify(json, null, 2)}
+ `; + } +} + +function ExtraPicklesSection({files}) { + return html` + <${Hider} name="Extra Pickles" shown=false> +
+ ${Object.entries(files).map(([fn, content]) => html`<${OnePickleSection} + filename=${fn} content=${content} />`)} +
`; +} + +class OnePickleSection extends Component { + constructor() { + super(); + this.state = { shown: false }; + } + + click() { + const shown = !this.state.shown; + this.setState({shown: shown}); + } + + render({filename, content}, {shown}) { + const header = html` +

+ this.click()} >${caret(shown)} + ${filename}

+ `; + if (!shown) { + return header; + } + return html` + ${header} +
${content}
+ `; + } +} + +function assertStorageAreEqual(key, lhs, rhs) { + if (lhs.length !== rhs.length || + !lhs.every((val, idx) => val === rhs[idx])) { + throw new Error("Storage mismatch for key '" + key + "'"); + } +} + +function computeTensorMemory(numel, dtype) { + const sizes = { + "Byte": 1, + "Char": 1, + "Short": 2, + "Int": 4, + "Long": 8, + "Half": 2, + "Float": 4, + "Double": 8, + "ComplexHalf": 4, + "ComplexFloat": 8, + "ComplexDouble": 16, + "Bool": 1, + "QInt8": 1, + "QUInt8": 1, + "QInt32": 4, + "BFloat16": 2, + }; + let dtsize = sizes[dtype]; + if (!dtsize) { + throw new Error("Unrecognized dtype: " + dtype); + } + return numel * dtsize; +} + +// TODO: Maybe track by dtype as well. +// TODO: Maybe distinguish between visible size and storage size. +function getTensorStorages(data) { + if (data === null) { + return new Map(); + } + if (typeof(data) == "boolean") { + return new Map(); + } + if (typeof(data) == "number") { + return new Map(); + } + if (typeof(data) == "string") { + return new Map(); + } + if (typeof(data) != "object") { + throw new Error("Not an object"); + } + if (Array.isArray(data)) { + let result = new Map(); + for (const item of data) { + const tensors = getTensorStorages(item); + for (const [key, storage] of tensors.entries()) { + if (!result.has(key)) { + result.set(key, storage); + } else { + const old_storage = result.get(key); + assertStorageAreEqual(key, old_storage, storage); + } + } + } + return result; + } + if (data.__tuple_values__) { + return getTensorStorages(data.__tuple_values__); + } + if (data.__is_dict__) { + return getTensorStorages(data.values); + } + if (data.__module_type__) { + return getTensorStorages(data.state); + } + if (data.__tensor_v2__) { + const [storage, offset, size, stride, grad] = data.__tensor_v2__; + const [dtype, key, device, numel] = storage; + return new Map([[key, storage]]); + } + if (data.__qtensor__) { + const [storage, offset, size, stride, quantizer, grad] = data.__qtensor__; + const [dtype, key, device, numel] = storage; + return new Map([[key, storage]]); + } + throw new Error("Can't handle data type.", data); +} + +function getTensorMemoryByDevice(pickles) { + let all_tensors = []; + for (const [name, pickle] of pickles) { + const tensors = getTensorStorages(pickle); + all_tensors.push(...tensors.values()); + } + let result = {}; + for (const storage of all_tensors.values()) { + const [dtype, key, device, numel] = storage; + const size = computeTensorMemory(numel, dtype); + result[device] = (result[device] || 0) + size; + } + return result; +} + +// Make this a separate component so it is rendered lazily. +class OpenTensorMemorySection extends Component { + render({model: {model_data, constants}}) { + let sizes = getTensorMemoryByDevice(new Map([ + ["data", model_data], + ["constants", constants], + ])); + return html` + + + + + + + + + + ${Object.entries(sizes).map(([dev, size]) => html` + + + + `)} + +
DeviceBytesHuman
${dev}${size}${humanFileSize(size)}
`; + } +} + +function TensorMemorySection({model}) { + return html` + <${Hider} name="Tensor Memory" shown=false> + <${OpenTensorMemorySection} model=${model} />`; +} + +class AuxContentPane extends Component { + constructor() { + super(); + this.state = { + blame_info: null, + }; + } + + doBlame(arg) { + this.setState({...this.state, blame_info: arg}); + } + + render({model: {interned_strings}}, {blame_info}) { + let blame_content = ""; + if (blame_info) { + const {ist_file, line, ist_s_text, s_start, s_end} = blame_info; + let s_text = interned_strings[ist_s_text]; + if (s_start != 0 || s_end != s_text.length) { + let prefix = s_text.slice(0, s_start); + let main = s_text.slice(s_start, s_end); + let suffix = s_text.slice(s_end); + s_text = html`${prefix}${main}${suffix}`; + } + blame_content = html` +

${interned_strings[ist_file]}:${line}

+
${s_start}:${s_end}
+
${s_text}

+ `; + } + return html` + +
+ ${blame_content} + `; + } +} + +class App extends Component { + constructor() { + super(); + this.state = { + err: false, + model: null, + }; + } + + componentDidMount() { + const app = this; + if (BURNED_IN_MODEL_INFO !== null) { + app.setState({model: BURNED_IN_MODEL_INFO}); + } else { + fetch("./model_info.json").then(function(response) { + if (!response.ok) { + throw new Error("Response not ok."); + } + return response.json(); + }).then(function(body) { + app.setState({model: body}); + }).catch(function(error) { + console.log("Top-level error: ", error); + }); + } + } + + componentDidCatch(error) { + void(error); + this.setState({...this.state, err: true}); + } + + render(_, {err}) { + if (this.state.model === null) { + return html`

Loading...

`; + } + + const model = this.state.model.model; + + let error_msg = ""; + if (err) { + error_msg = html`

An error occurred. Check console

`; + } + + return html` + ${error_msg} +
+

TorchScript Model (version ${model.version}): ${model.title}

+ + <${ModelSizeSection} model=${model}/> + <${StructuredDataSection} name="Model Data" data=${model.model_data} shown=true/> + <${StructuredDataSection} name="Constants" data=${model.constants} shown=false/> + <${ZipContentsSection} model=${model}/> + <${CodeSection} model=${model}/> + <${ExtraJsonSection} files=${model.extra_files_jsons}/> + <${ExtraPicklesSection} files=${model.extra_pickles}/> + <${TensorMemorySection} model=${model}/> +
+
+ <${AuxContentPane} + err=${this.state.error} + model=${model} + ref=${(p) => blame.setAuxContentPane(p)}/> +
+ `; + } +} + +render(h(App), document.body); diff --git a/.venv/lib/python3.12/site-packages/torch/utils/model_dump/htm.mjs b/.venv/lib/python3.12/site-packages/torch/utils/model_dump/htm.mjs new file mode 100644 index 0000000000000000000000000000000000000000..06f25a13d8021ff4f43de442bbf0279f24735d6c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/model_dump/htm.mjs @@ -0,0 +1,2 @@ +// HTM, Apache License +var n=function(t,s,r,e){var u;s[0]=0;for(var h=1;h=5&&((e||!n&&5===r)&&(h.push(r,0,e,s),r=6),n&&(h.push(r,n,0,s),r=6)),e=""},a=0;a"===t?(r=1,e=""):e=t+e[0]:u?t===u?u="":e+=t:'"'===t||"'"===t?u=t:">"===t?(p(),r=1):r&&("="===t?(r=5,s=e,e=""):"/"===t&&(r<5||">"===n[a][l+1])?(p(),3===r&&(h=h[0]),r=h,(h=h[0]).push(2,0,r),r=0):" "===t||"\t"===t||"\n"===t||"\r"===t?(p(),r=2):e+=t),3===r&&"!--"===e&&(r=4,h=h[0])}return p(),h}(s)),r),arguments,[])).length>1?r:r[0]} diff --git a/.venv/lib/python3.12/site-packages/torch/utils/model_dump/preact.mjs b/.venv/lib/python3.12/site-packages/torch/utils/model_dump/preact.mjs new file mode 100644 index 0000000000000000000000000000000000000000..8c85bd948c6772ca8d40fc8d6fab6a220d55a1ef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/model_dump/preact.mjs @@ -0,0 +1,2 @@ +// Preact, MIT License +var n,l,u,i,t,o,r={},f=[],e=/acit|ex(?:s|g|n|p|$)|rph|grid|ows|mnc|ntw|ine[ch]|zoo|^ord|itera/i;function c(e,n){for(var t in n)e[t]=n[t];return e}function s(e){var n=e.parentNode;n&&n.removeChild(e)}function a(e,n,t){var _,l,o,r=arguments,i={};for(o in n)"key"==o?_=n[o]:"ref"==o?l=n[o]:i[o]=n[o];if(arguments.length>3)for(t=[t],o=3;o0?v(m.type,m.props,m.key,null,m.__v):m)){if(m.__=t,m.__b=t.__b+1,null===(h=P[p])||h&&m.key==h.key&&m.type===h.type)P[p]=void 0;else for(a=0;a3)for(t=[t],o=3;o + + + TorchScript Model + + + + + + + + diff --git a/.venv/lib/python3.12/site-packages/torch/utils/serialization/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/serialization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d63bc18b69b138a026622de599aed656cc868c8e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/serialization/__init__.py @@ -0,0 +1 @@ +from . import config diff --git a/.venv/lib/python3.12/site-packages/torch/utils/serialization/config.py b/.venv/lib/python3.12/site-packages/torch/utils/serialization/config.py new file mode 100644 index 0000000000000000000000000000000000000000..0a3fba9f5b82f88f362cd8361656d7820e0216a1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/serialization/config.py @@ -0,0 +1,25 @@ +import sys +from typing import Optional as _Optional, TYPE_CHECKING as _TYPE_CHECKING + + +if _TYPE_CHECKING: + from torch.serialization import LoadEndianness as _LoadEndianess + +from torch.utils._config_module import install_config_module as _install_config_module + + +class load: + mmap: bool = False + endianness: _Optional["_LoadEndianess"] = None + # MAP_PRIVATE = 2 + mmap_flags: _Optional[int] = None if sys.platform == "win32" else 2 + calculate_storage_offsets: bool = False + + +class save: + compute_crc32: bool = True + use_pinned_memory_for_d2h: bool = False + storage_alignment: int = 64 + + +_install_config_module(sys.modules[__name__]) diff --git a/.venv/lib/python3.12/site-packages/torch/utils/viz/__init__.py b/.venv/lib/python3.12/site-packages/torch/utils/viz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.12/site-packages/torch/utils/viz/_cycles.py b/.venv/lib/python3.12/site-packages/torch/utils/viz/_cycles.py new file mode 100644 index 0000000000000000000000000000000000000000..455810310817346fc747ed3488a7284695880c83 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/torch/utils/viz/_cycles.py @@ -0,0 +1,499 @@ +# mypy: allow-untyped-defs +import gc +import sys +from typing import Any, NamedTuple, Optional +import types +import weakref +import json +from tempfile import NamedTemporaryFile +import torch +from torch.cuda._memory_viz import _frames_fmt, _block_extra +import atexit +import logging +logger = logging.getLogger(__name__) + +def observe_garbage(observer): + enabled = True + + def disable(): + # when GC runs during exit, things like `sys` will already be unloaded + # so we have to disable the callback to avoid hitting errors. + nonlocal enabled + enabled = False + atexit.register(disable) + + def gc_callback(phase, info): + nonlocal enabled + if not enabled: + return + if phase == "start": + gc.set_debug(gc.DEBUG_SAVEALL) + elif phase == "stop": + orig_trace = sys.getprofile() + self_return = [False] + + def do_collect(*args, **kwargs): + nonlocal enabled + if not self_return[0]: + self_return[0] = True + else: + sys.setprofile(orig_trace) + enabled = False + try: + # things in gc.garbage have survived a collection + # so to free them we have to collect a generation greater than them + # but that might _also_ free other stuff and we don't want to miss + # that stuff. So we have to now force gc at the highest level here, + # report all of what we found, _then_ we can free it up. + if info['generation'] != 2: + gc.collect() + observer(gc.garbage) + gc.garbage.clear() + # we have to re-run GC to clean up the cycles + # we saved from before. + gc.set_debug(0) + before = torch.cuda.memory_allocated() + gc.collect() + after = torch.cuda.memory_allocated() + if before != after: + logger.warning("CUDA Memory changed during GC, %d bytes freed.", before - after) + finally: + enabled = True + if orig_trace is not None: + return orig_trace(*args, **kwargs) + sys.setprofile(do_collect) + + gc.callbacks.append(gc_callback) + + # provide a way to disarm the callback + def remove(): + gc.callbacks.remove(gc_callback) + return remove + +# Function to visualize cycles adapated from refcycle: +# Copyright 2013 Mark Dickinson +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def _get_cell_type(): + def f(x=None): + return lambda: x + return type(f().__closure__[0]) + +CellType = _get_cell_type() + +def annotated_references(obj): + """ + Return known information about references held by the given object. + + Returns a mapping from referents to lists of descriptions. Note that there + may be more than one edge leading to any particular referent; hence the + need for a list. Descriptions are currently strings. + + """ + references: dict[int, list[str]] = {} + + def add_reference(name, obj): + references.setdefault(id(obj), []).append(name) + + def add_attrs(*attrs): + for attr in attrs: + if hasattr(obj, attr): + add_reference(attr, getattr(obj, attr)) + + def add_cell_references(): + try: + add_attrs("cell_contents") + except ValueError: + # if cell_contents is empty, + # accessing it raises ValueError + # in this case there is no object to + # annotate + pass + + def add_function_references(): + add_attrs("__defaults__", + "__closure__", + "__globals__", + "__code__", + "__name__", + "__module__", + "__doc__" + "__qualname__", + "__annotations__", + "__kwdefaults__") + + + def add_sequence_references(): + for position, item in enumerate(obj): + add_reference(f"[{position}]", item) + + def add_dict_references(): + for key, value in obj.items(): + add_reference("key", key) + add_reference(f"[{repr(key)}]", value) + + def add_set_references(): + for elt in obj: + add_reference("element", elt) + + def add_bound_method_references(): + add_attrs("__self__", "__func__", "im_class") + + def add_weakref_references(): + # For subclasses of weakref, we can't reliably distinguish the + # callback (if any) from other attributes. + if type(obj) is weakref.ref: + referents = gc.get_referents(obj) + if len(referents) == 1: + target = referents[0] + add_reference("__callback__", target) + + + def add_frame_references(): + f_locals = obj.f_locals + add_attrs("f_back", "f_code", "f_builtins", "f_globals", "f_trace", "f_locals") + # Some badly-behaved code replaces the f_locals dict with + # something that doesn't support the full dict interface. So we + # only continue with the annotation if f_locals is a Python dict. + if type(f_locals) is dict: + for name, local in obj.f_locals.items(): + add_reference(f"local {name}", local) + + def add_getset_descriptor_references(): + add_attrs("__objclass__", "__name__", "__doc__") + + type_based_references = { + tuple: add_sequence_references, + list: add_sequence_references, + dict: add_dict_references, + set: add_set_references, + frozenset: add_set_references, + types.FunctionType: add_function_references, + types.FrameType: add_frame_references, + CellType: add_cell_references, + types.MethodType: add_bound_method_references, + weakref.ref: add_weakref_references, + types.GetSetDescriptorType: add_getset_descriptor_references, + } + + for type_ in type(obj).__mro__: + if type_ in type_based_references: + type_based_references[type_]() + + add_attrs("__dict__", "__class__") + if isinstance(obj, type): + add_attrs("__mro__") + + return references + +############################################################################### +# Object annotations. + + +BASE_TYPES = (int, float, complex, type(None), str, bytes) +FRAME_FILENAME_LIMIT = 32 + +def object_annotation(obj): + """ + Return a string to be used for Graphviz nodes. + + The string should be short but as informative as possible. + """ + + def format_sequence(obj): + body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for i, x in zip(range(8), obj)) + if len(obj) > 8: + body = f'{body}, ...{len(obj) - 8}' + return body + + # For basic types, use the repr. + if isinstance(obj, BASE_TYPES): + return repr(obj) + if type(obj).__name__ == 'function': + return f"function\n{obj.__name__}" + elif isinstance(obj, types.MethodType): + try: + func_name = obj.__func__.__qualname__ + except AttributeError: + func_name = "" + return f"instancemethod\n{func_name}" + elif isinstance(obj, list): + return f"[{format_sequence(obj)}]" + elif isinstance(obj, tuple): + return f"({format_sequence(obj)})" + elif isinstance(obj, dict): + return f"dict[{len(obj)}]" + elif isinstance(obj, types.ModuleType): + return f"module\n{obj.__name__}" + elif isinstance(obj, type): + return f"type\n{obj.__name__}" + elif isinstance(obj, weakref.ref): + referent = obj() + if referent is None: + return "weakref (dead referent)" + else: + return f"weakref to id 0x{id(referent):x}" + elif isinstance(obj, types.FrameType): + filename = obj.f_code.co_filename + if len(filename) > FRAME_FILENAME_LIMIT: + filename = "..." + filename[-(FRAME_FILENAME_LIMIT - 3):] + return f"frame\n{filename}:{obj.f_lineno}" + else: + return f"object\n{type(obj).__module__}.{type(obj).__name__}" + + + +class Node(NamedTuple): + label: str + context: Optional[str] + root: bool + referrents: list[tuple[str, int]] + +def create_graph(objects, *, context=None, filter=None): + if context is None: + context = cuda_allocation_context() + if filter is None: + filter = is_cuda_tensor + + objects = [obj for obj in objects if not isinstance(obj, weakref.ProxyTypes)] + nodes = [Node(object_annotation(obj), context(obj), filter(obj), []) for obj in objects] + node_referrers: list[list[int]] = [[] for obj in objects] + + id_to_node = {id(obj): i for i, obj in enumerate(objects)} + for obj in objects: + fidx = id_to_node[id(obj)] + f = nodes[fidx] + references = annotated_references(obj) + for referrent in gc.get_referents(obj): + rid = id(referrent) + tidx = id_to_node.get(rid, None) + if tidx is None: + continue + labels = references.get(rid, ["?"]) + node_referrers[tidx].append(fidx) + for label in labels: + f.referrents.append((label, tidx)) + + to_search = [i for i, n in enumerate(nodes) if n.root] + to_keep = set() + while to_search: + idx = to_search.pop() + if idx in to_keep: + continue + to_keep.add(idx) + referrers = node_referrers[idx] + to_search.extend(referrers) + id_to_filtered_id: dict[int, int] = {} + filtered: list[Any] = [] + for i, n in enumerate(nodes): + if i in to_keep: + id_to_filtered_id[i] = len(id_to_filtered_id) + filtered.append(n) + for n in filtered: + n.referrents[:] = [(label, id_to_filtered_id[idx]) + for (label, idx) in n.referrents + if idx in id_to_filtered_id] + return filtered + +def escape(n): + return json.dumps(n) + + +def is_cuda_tensor(obj): + return isinstance(obj, torch.Tensor) and obj.is_cuda and not isinstance(obj, torch._subclasses.FakeTensor) + +def cuda_allocation_context(): + snapshot = torch.cuda.memory._snapshot() + addr_to_frame = {} + for seg in snapshot['segments']: + addr = seg['address'] + for blk in seg['blocks']: + if blk['state'] == 'active_allocated': + frames, _real_size = _block_extra(blk) + addr_to_frame[addr] = frames + addr += blk['size'] + + def object_context(obj): + if is_cuda_tensor(obj): + addr = obj.untyped_storage().data_ptr() + frames = addr_to_frame.get(addr) + if frames is not None: + return '\n'.join(_frames_fmt(frames, full_filename=True)) + return None + return object_context + +def to_dot(nodes): + lines = ["digraph GraphName {", "node [shape=rect];", 'rankdir=LR;'] + for i, n in enumerate(nodes): + lines.append(f'{i} [label={escape(n.label)}, color={ "red" if n.root else "black"}];') + + for i, f in enumerate(nodes): + for label, j in f.referrents: + lines.append(f'{i} -> {j} [label = {escape(label)}]') + lines.append("}\n") + return '\n'.join(lines) + +_template = """ + + + + + + +
+
+
+
+
Mouse over tensor objects to see where they were allocated.
+
+
+ + + + +""" +_listener_template = """ +document.getElementById('node{id}').addEventListener('mouseover', function(event) {{ + document.getElementById("stacktrace").textContent = {stack} +}}) +""" +def to_html(nodes): + listeners = [] + for i, n in enumerate(nodes): + if n.context is None: + continue + s = _listener_template.format(id=str(i + 1), stack=escape(f'{n.label}:\n{n.context}')) + listeners.append(s) + dot = to_dot(nodes) + return _template.replace('$DOT', repr(dot)).replace('$LISTENERS', '\n'.join(listeners)) + +def observe_tensor_cycles(callback): + torch.cuda.memory._record_memory_history(max_entries=100000) + + def observer(garbage): + if garbage: + if not any(is_cuda_tensor(obj) for obj in garbage): + logger.info("No CUDA Tensors found in garbage") + return + callback(to_html(create_graph(garbage))) + return observe_garbage(observer) + + +def warn_tensor_cycles(): + """ + Install a warning that reports whenever a cycle that is holding CUDA memory is observed. + + The warning produces an .html file that visualizes the cycle, + and links it to the stack frame that allocted the CUDA tensor. + + Reference cycles are freed by the cycle collector rather than being cleaned up + when the objects in the cycle first become unreachable. If a cycle points to a tensor, + the CUDA memory for that tensor will not be freed until garbage collection runs. + Accumulation of CUDA allocations can lead to out of memory errors (OOMs), as well as + non-deterministic allocation behavior which is harder to debug. + """ + logger.info("Watching Python reference cycles for CUDA Tensors.") + + def write_and_log(html): + with NamedTemporaryFile('w', suffix='.html', delete=False) as f: + f.write(html) + logger.warning('Reference cycle includes a CUDA Tensor see visualization of cycle %s', f.name) + return observe_tensor_cycles(write_and_log)